@@ -102,6 +102,10 @@ def load_model(model_name):
         if path_to_model.name.lower().startswith('llama-30b'):
             pt_model = 'llama-30b-4bit.pt'
 
+        if not Path(f"models/{pt_model}").exists():
+            print(f"Could not find models/{pt_model}, exiting...")
+            exit()
+
         model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
         model = model.to(torch.device('cuda:0'))
 
@@ -178,4 +182,3 @@ def load_soft_prompt(name):
         shared.soft_prompt_tensor = tensor
 
     return name
-
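
For context, a minimal standalone sketch of the behavior the first hunk introduces; the pt_model value is taken from the surrounding context lines, and everything outside the guard is illustrative rather than part of the patch:

    from pathlib import Path

    pt_model = 'llama-30b-4bit.pt'  # set by the branch shown in the hunk above

    # The new guard: fail fast with a readable message instead of letting
    # load_quant() raise on a missing quantized checkpoint file.
    if not Path(f"models/{pt_model}").exists():
        print(f"Could not find models/{pt_model}, exiting...")
        exit()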