Explorar el Código

Give some default options in the download script

oobabooga hace 3 años
padre
commit
fd8070b960
Se han modificado 1 ficheros con 59 adiciones y 12 borrados
  1. 59 12
      download-model.py

+ 59 - 12
download-model.py

@@ -16,7 +16,7 @@ import tqdm
 from bs4 import BeautifulSoup
 
 parser = argparse.ArgumentParser()
-parser.add_argument('MODEL', type=str)
+parser.add_argument('MODEL', type=str, default=None, nargs='?')
 parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
 parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
 parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
@@ -46,20 +46,67 @@ def sanitize_branch_name(branch_name):
     else:
         raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
 
+def select_model_from_default_options():
+    models = {
+        "Pygmalion 6B original": ("PygmalionAI", "pygmalion-6b", "b8344bb4eb76a437797ad3b19420a13922aaabe1"),
+        "Pygmalion 6B main": ("PygmalionAI", "pygmalion-6b", "main"),
+        "Pygmalion 6B dev": ("PygmalionAI", "pygmalion-6b", "dev"),
+        "Pygmalion 2.7B": ("PygmalionAI", "pygmalion-2.7b", "main"),
+        "Pygmalion 1.3B": ("PygmalionAI", "pygmalion-1.3b", "main"),
+        "Pygmalion 350m": ("PygmalionAI", "pygmalion-350m", "main"),
+        "OPT 6.7b": ("facebook", "opt-6.7b", "main"),
+        "OPT 2.7b": ("facebook", "opt-2.7b", "main"),
+        "OPT 1.3b": ("facebook", "opt-1.3b", "main"),
+        "OPT 350m": ("facebook", "opt-350m", "main"),
+    }
+    choices = {}
+
+    print("Select the model that you want to download:\n")
+    for i,name in enumerate(models):
+        char = chr(ord('A')+i)
+        choices[char] = name
+        print(f"{char}) {name}")
+    char = chr(ord('A')+len(models))
+    print(f"{char}) None of the above")
+
+    print()
+    print("Input> ", end='')
+    choice = input()[0]
+    if choice == char:
+        print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
+
+Examples:
+PygmalionAI/pygmalion-6b
+facebook/opt-1.3b
+""")
+
+        print("Input> ", end='')
+        model = input()
+        branch = "main"
+    else:
+        arr = models[choices[choice]]
+        model = f"{arr[0]}/{arr[1]}"
+        branch = arr[2]
+
+    return model, branch
+
 if __name__ == '__main__':
     model = args.MODEL
-    if model[-1] == '/':
-        model = model[:-1]
-        branch = args.branch
-    if args.branch is None:
-        branch = 'main'
+    branch = args.branch
+    if model is None:
+        model, branch = select_model_from_default_options()
     else:
-        try:
-            branch_name = args.branch
-            branch = sanitize_branch_name(branch_name)
-        except ValueError as err_branch:
-            print(f"Error: {err_branch}")
-            sys.exit()
+        if model[-1] == '/':
+            model = model[:-1]
+            branch = args.branch
+        if branch is None:
+            branch = "main"
+        else:
+            try:
+                branch = sanitize_branch_name(branch)
+            except ValueError as err_branch:
+                print(f"Error: {err_branch}")
+                sys.exit()
     url = f'https://huggingface.co/{model}/tree/{branch}'
     if branch != 'main':
         output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')