Procházet zdrojové kódy

Add --branch option to the model download script

81300 před 3 roky
rodič
revize
fffd49e64e
1 změnil soubory, kde provedl 31 přidání a 3 odebrání
  1. 31 3
      download-model.py

+ 31 - 3
download-model.py

@@ -10,8 +10,16 @@ import requests
 from bs4 import BeautifulSoup 
 import multiprocessing
 import tqdm
+import sys
 from sys import argv
+import argparse
 from pathlib import Path
+import re
+
+parser = argparse.ArgumentParser()
+parser.add_argument('MODEL', type=str)
+parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+args = parser.parse_args()
 
 def get_file(args):
     url = args[0]
@@ -27,12 +35,32 @@ def get_file(args):
             f.write(data)
         t.close()
 
+def sanitize_branch_name(branch_name):
+    pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
+    if pattern.match(branch_name):
+        return branch_name
+    else:
+        raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
+
 if __name__ == '__main__':
     model = argv[1]
     if model[-1] == '/':
         model = model[:-1]
-    url = f'https://huggingface.co/{model}/tree/main'
-    output_folder = Path("models") / model.split('/')[-1]
+        branch = args.branch
+    if args.branch is None:
+        branch = 'main'
+    else:
+        try:
+            branch_name = args.branch
+            branch = sanitize_branch_name(branch_name)
+        except ValueError as err_branch:
+            print(f"Error: {err_branch}")
+            sys.exit()
+    url = f'https://huggingface.co/{model}/tree/{branch}'
+    if branch != 'main':
+        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+    else:
+        output_folder = Path("models") / model.split('/')[-1]
     if not output_folder.exists():
         output_folder.mkdir()
 
@@ -43,7 +71,7 @@ if __name__ == '__main__':
     downloads = []
     for link in links:
         href = link.get('href')[1:]
-        if href.startswith(f'{model}/resolve/main'):
+        if href.startswith(f'{model}/resolve/{branch}'):
             if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
                 downloads.append(f'https://huggingface.co/{href}')