Просмотр исходного кода

Merge pull request #210 from rohvani/pt-path-changes

Add llama-65b-4bit.pt support
oobabooga 2 года назад
Родитель
Commit
e01da4097c
1 изменённый файл: 14 добавлений и 7 удалений
  1. 14 7
      modules/models.py

+ 14 - 7
modules/models.py

@@ -97,16 +97,23 @@ def load_model(model_name):
         pt_model = ''
         if path_to_model.name.lower().startswith('llama-7b'):
             pt_model = 'llama-7b-4bit.pt'
-        if path_to_model.name.lower().startswith('llama-13b'):
+        elif path_to_model.name.lower().startswith('llama-13b'):
             pt_model = 'llama-13b-4bit.pt'
-        if path_to_model.name.lower().startswith('llama-30b'):
+        elif path_to_model.name.lower().startswith('llama-30b'):
             pt_model = 'llama-30b-4bit.pt'
+        elif path_to_model.name.lower().startswith('llama-65b'):
+            pt_model = 'llama-65b-4bit.pt'
+        else:
+            pt_model = f'{model_name}-4bit.pt'
 
-        if not Path(f"models/{pt_model}").exists():
-            print(f"Could not find models/{pt_model}, exiting...")
-            exit()
-        elif pt_model == '':
-            print(f"Could not find the .pt model for {model_name}, exiting...")
+        # Try to find the .pt both in models/ and in the subfolder
+        pt_path = None
+        for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+            if path.exists():
+                pt_path = path
+
+        if not pt_path:
+            print(f"Could not find {pt_model}, exiting...")
             exit()
 
         model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)