oobabooga 3 years ago
Parent
Commit f0013ac8e9
1 changed file with 2 additions and 7 deletions

server.py +2 −7

@@ -27,11 +27,6 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     print(f"Loading {model_name}...")
     t0 = time.time()
     t0 = time.time()
 
 
-    if args.cpu:
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
-
     # Loading the model
     if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
         print("Loading in .pt format...")
@@ -45,9 +40,9 @@ def load_model(model_name):
             model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
     else:
         if args.cpu:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float32)
         else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype).cuda()
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
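In effect, this commit drops the upfront `dtype` assignment and inlines the precision choice at each `from_pretrained` call: `torch.float32` on the CPU path, `torch.float16` plus `.cuda()` on the GPU path. Below is a minimal sketch of the resulting load path, not the full server.py; it assumes the standard transformers API and swaps the script's `args.cpu` for a plain `cpu` parameter so the snippet is self-contained.

```python
# Minimal sketch of the load path after this commit (not the full server.py):
# the precision is picked inline rather than via a separate `dtype` variable.
import time
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM


def load_model(model_name, cpu=False):  # `cpu` stands in for the script's args.cpu
    print(f"Loading {model_name}...")
    t0 = time.time()
    if cpu:
        # CPU path: full precision, since float16 kernels are poorly supported on CPU
        model = AutoModelForCausalLM.from_pretrained(
            Path(f"models/{model_name}"),
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32,
        )
    else:
        # GPU path: half precision to roughly halve VRAM usage
        model = AutoModelForCausalLM.from_pretrained(
            Path(f"models/{model_name}"),
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        ).cuda()
    print(f"Loaded the model in {time.time() - t0:.2f} seconds.")
    return model
```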