소스 검색

Refactor download process to use multiprocessing

The previous implementation used threads to download files in parallel, which could lead to performance issues due to the Global Interpreter Lock (GIL).
This commit refactors the download process to use multiprocessing instead,
which allows for true parallelism across multiple CPUs.
This results in significantly faster downloads, particularly for large models.
Nikita Skakun 2 년 전
부모
커밋
4d8e101006
1개의 변경된 파일16개의 추가작업 그리고 11개의 파일을 삭제
  1. 16 11
      download-model.py

+ 16 - 11
download-model.py

@@ -17,13 +17,6 @@ from pathlib import Path
 import requests
 import tqdm
 
-parser = argparse.ArgumentParser()
-parser.add_argument('MODEL', type=str, default=None, nargs='?')
-parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
-parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
-parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
-args = parser.parse_args()
-
 def get_file(args):
     url = args[0]
     output_folder = args[1]
@@ -150,7 +143,22 @@ def get_download_links_from_huggingface(model, branch):
 
     return links, is_lora
 
+def download_files(file_list, output_folder, num_processes=8):
+    with multiprocessing.Pool(processes=num_processes) as pool:
+        args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)]
+        for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)):
+            pass
+        pool.close()
+        pool.join()
+
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('MODEL', type=str, default=None, nargs='?')
+    parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+    parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
+    parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
+    args = parser.parse_args()
+
     model = args.MODEL
     branch = args.branch
     if model is None:
@@ -179,7 +187,4 @@ if __name__ == '__main__':
 
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
-    pool = multiprocessing.Pool(processes=args.threads)
-    results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
-    pool.close()
-    pool.join()
+    download_files(links, output_folder, num_processes=args.threads)