Parcourir la source

Add --threads option to the download script

oobabooga il y a 3 ans
Parent
commit
9215e281ba
1 fichier modifié avec 7 ajouts et 3 suppressions
  1. 7 3
      download-model.py

+ 7 - 3
download-model.py

@@ -18,12 +18,16 @@ import re
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str)
 parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
 args = parser.parse_args()
 
 def get_file(args):
     url = args[0]
     output_folder = args[1]
+    idx = args[2]
+    tot = args[3]
 
+    print(f"Downloading file {idx} of {tot}...")
     r = requests.get(url, stream=True)
     with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
         total_size = int(r.headers.get('content-length', 0))
@@ -77,8 +81,8 @@ if __name__ == '__main__':
                 downloads.append(f'https://huggingface.co/{href}')
 
     # Downloading the files
-    print(f"Downloading the model to {output_folder}...")
-    pool = multiprocessing.Pool(processes=4)
-    results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
+    print(f"Downloading the model to {output_folder}")
+    pool = multiprocessing.Pool(processes=args.threads)
+    results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
     pool.close()
     pool.join()