Selaa lähdekoodia

allow quantized model to be loaded from model dir (#760)

catalpaaa 2 vuotta sitten
vanhempi
commit
4ab679480e
2 muutettua tiedostoa jossa 5 lisäystä ja 5 poistoa
  1. 3 3
      modules/GPTQ_loader.py
  2. 2 2
      modules/models.py

+ 3 - 3
modules/GPTQ_loader.py

@@ -74,7 +74,7 @@ def load_quantized(model_name):
         exit()
 
     # Now we are going to try to locate the quantized model file.
-    path_to_model = Path(f'models/{model_name}')
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     found_pts = list(path_to_model.glob("*.pt"))
     found_safetensors = list(path_to_model.glob("*.safetensors"))
     pt_path = None
@@ -95,8 +95,8 @@ def load_quantized(model_name):
         else:
             pt_model = f'{model_name}-{shared.args.wbits}bit'
 
-        # Try to find the .safetensors or .pt both in models/ and in the subfolder
-        for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+        # Try to find the .safetensors or .pt both in the model dir and in the subfolder
+        for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
             if path.exists():
                 print(f"Found {path}")
                 pt_path = path

+ 2 - 2
modules/models.py

@@ -42,7 +42,7 @@ def load_model(model_name):
     t0 = time.time()
 
     shared.is_RWKV = 'rwkv-' in model_name.lower()
-    shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
+    shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
 
     # Default settings
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
@@ -105,7 +105,7 @@ def load_model(model_name):
     elif shared.is_llamacpp:
         from modules.llamacpp_model import LlamaCppModel
 
-        model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
+        model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
         print(f"llama.cpp weights detected: {model_file}\n")
 
         model, tokenizer = LlamaCppModel.from_pretrained(model_file)