
Refactor model loading function

oobabooga 3 years ago
commit 00a12889e9
2 files changed with 10 additions and 11 deletions
  1. server.py (+10 −11)
  2. torch-dumps/place-your-pt-models-here.txt (+0 −0)

server.py (+10 −11)

@@ -36,15 +36,18 @@ def load_model(model_name):
     if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
         print("Loading in .pt format...")
         model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
-    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')):
-        if any(size in model_name.lower() for size in ('13b', '20b', '30b')):
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
-        else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
+        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
     elif model_name in ['flan-t5', 't5-large']:
-        model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
+        if args.cpu:
+            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
+        else:
+            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
     else:
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+        if args.cpu:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype).cuda()
 
     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
@@ -54,10 +57,6 @@ def load_model(model_name):
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
 
-    # Sending to the GPU
-    if not (args.cpu or any(size in model_name.lower() for size in ('13b', '20b', '30b'))):
-        model = model.cuda()
-
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 

torch-dumps/place-your-pt-models-here.txt (+0 −0)
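
For context, here is a minimal standalone sketch of what load_model's branch logic amounts to after this commit. It assumes the imports and globals defined elsewhere in server.py (torch, transformers' AutoModelForCausalLM and T5ForConditionalGeneration, pathlib's Path, and the args/dtype globals); the device-placement checks are restructured slightly for readability, so treat it as an illustration rather than the exact file contents:

    def load_model(model_name):
        # Sketch only: torch, Path, args, dtype, and the transformers classes
        # come from server.py's own imports/globals (assumed, not shown here).
        name = model_name.lower()
        if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
            # A pre-serialized .pt dump skips from_pretrained entirely.
            model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
        elif name.startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in name for size in ('13b', '20b', '30b')):
            # Large models: shard across devices and quantize to 8-bit.
            # No .cuda() here, since device_map='auto' already assigns devices.
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
        elif model_name in ['flan-t5', 't5-large']:
            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
            if not args.cpu:
                model = model.cuda()
        else:
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
            if not args.cpu:
                model = model.cuda()
        return model

The point of the refactor is visible here: the old trailing "Sending to the GPU" block, which had to re-derive from the model name whether .cuda() was safe, is gone. Each branch now decides device placement where the model is created, so the 8-bit device_map='auto' path can no longer be moved to a single GPU by mistake.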