Explorar o código

Merge pull request #618 from nikita-skakun/optimize-download-model

Improve download-model.py progress bar with multiple threads
oobabooga %!s(int64=2) %!d(string=hai) anos
pai
achega
9104164297
Modificáronse 1 ficheiros con 11 adicións e 18 borrados
  1. 11 18
      download-model.py

+ 11 - 18
download-model.py

@@ -10,13 +10,13 @@ import argparse
 import base64
 import datetime
 import json
-import multiprocessing
 import re
 import sys
 from pathlib import Path
 
 import requests
 import tqdm
+from tqdm.contrib.concurrent import thread_map
 
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
@@ -26,22 +26,15 @@ parser.add_argument('--text-only', action='store_true', help='Only download text
 parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
 args = parser.parse_args()
 
-def get_file(args):
-    url = args[0]
-    output_folder = args[1]
-    idx = args[2]
-    tot = args[3]
-
-    print(f"Downloading file {idx} of {tot}...")
+def get_file(url, output_folder):
     r = requests.get(url, stream=True)
-    with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
+    with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
         total_size = int(r.headers.get('content-length', 0))
         block_size = 1024
-        t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
-        for data in r.iter_content(block_size):
-            t.update(len(data))
-            f.write(data)
-        t.close()
+        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
+            for data in r.iter_content(block_size):
+                t.update(len(data))
+                f.write(data)
 
 def sanitize_branch_name(branch_name):
     pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
@@ -152,6 +145,9 @@ def get_download_links_from_huggingface(model, branch):
 
     return links, is_lora
 
+def download_files(file_list, output_folder, num_threads=8):
+    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False)
+
 if __name__ == '__main__':
     model = args.MODEL
     branch = args.branch
@@ -192,7 +188,4 @@ if __name__ == '__main__':
 
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
-    pool = multiprocessing.Pool(processes=args.threads)
-    results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
-    pool.close()
-    pool.join()
+    download_files(links, output_folder, args.threads)