Ver código fonte

Improve progress bar visual style

This commit reverts the performance improvements of the previous commit for for improved visual style of multithreaded progress bars. The style of the progress bar has been modified to take up the same amount of size to align them.
Nikita Skakun 2 anos atrás
pai
commit
ff515ec2fe
1 arquivos alterados com 10 adições e 21 exclusões
  1. 10 21
      download-model.py

+ 10 - 21
download-model.py

@@ -16,23 +16,17 @@ from pathlib import Path
 
 
 import requests
 import requests
 import tqdm
 import tqdm
+from tqdm.contrib.concurrent import thread_map
 
 
-def get_file(args):
-    url = args[0]
-    output_folder = args[1]
-    idx = args[2]
-    tot = args[3]
-
-    print(f"Downloading file {idx} of {tot}...")
+def get_file(url, output_folder):
     r = requests.get(url, stream=True)
     r = requests.get(url, stream=True)
-    with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
+    with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
         total_size = int(r.headers.get('content-length', 0))
         total_size = int(r.headers.get('content-length', 0))
         block_size = 1024
         block_size = 1024
-        t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
-        for data in r.iter_content(block_size):
-            t.update(len(data))
-            f.write(data)
-        t.close()
+        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
+            for data in r.iter_content(block_size):
+                t.update(len(data))
+                f.write(data)
 
 
 def sanitize_branch_name(branch_name):
 def sanitize_branch_name(branch_name):
     pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
     pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
@@ -143,13 +137,8 @@ def get_download_links_from_huggingface(model, branch):
 
 
     return links, is_lora
     return links, is_lora
 
 
-def download_files(file_list, output_folder, num_processes=8):
-    with multiprocessing.Pool(processes=num_processes) as pool:
-        args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)]
-        for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)):
-            pass
-        pool.close()
-        pool.join()
+def download_files(file_list, output_folder, num_threads=8):
+    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False)
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
@@ -187,4 +176,4 @@ if __name__ == '__main__':
 
 
     # Downloading the files
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     print(f"Downloading the model to {output_folder}")
-    download_files(links, output_folder, num_processes=args.threads)
+    download_files(links, output_folder, args.threads)