
Better defaults while loading models

oobabooga 3 years ago
Parent commit ee650343bc
2 changed files with 12 additions and 5 deletions
  1. server.py  +12 -5
  2. torch-dumps/place-your-pt-models-here.txt  +0 -0

server.py (+12 -5)

@@ -12,8 +12,8 @@ from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2
 #model_name = 'gpt-j-6B-float16'
 #model_name = "opt-6.7b"
 #model_name = 'opt-13b'
-#model_name = "gpt4chan_model_float16"
-model_name = 'galactica-6.7b'
+model_name = "gpt4chan_model_float16"
+#model_name = 'galactica-6.7b'
 #model_name = 'gpt-neox-20b'
 #model_name = 'flan-t5'
 #model_name = 'OPT-13B-Erebus'
@@ -24,17 +24,24 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
+    # Loading the model
     if os.path.exists(f"torch-dumps/{model_name}.pt"):
         print("Loading in .pt format...")
         model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
-    elif model_name.lower().startswith(('gpt-neo', 'opt-')):
-        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
+    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')):
+        if any(size in model_name for size in ('13b', '20b', '30b')):
+            model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
     elif model_name in ['gpt-j-6B']:
         model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
     elif model_name in ['flan-t5', 't5-large']:
         model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
+    else:
+        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
 
-    if model_name in ['gpt4chan_model_float16']:
+    # Loading the tokenizer
+    if model_name.startswith('gpt4chan'):
         tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
     elif model_name in ['flan-t5']:
         tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
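In sum, the hunk above routes every matching causal model through a size check: names containing '13b', '20b', or '30b' are loaded in 8-bit and sharded across available devices, while everything else falls back to float16 on a single GPU. A minimal standalone sketch of that dispatch (load_causal_lm is a hypothetical helper name; it assumes the repo's models/ directory layout and that accelerate and bitsandbytes are installed for the 8-bit path):

import torch
from transformers import AutoModelForCausalLM

def load_causal_lm(model_name):
    # Mirrors the new branch in load_model(): big checkpoints go
    # 8-bit across available devices, smaller ones are loaded in
    # float16 and moved to a single GPU.
    path = f"models/{model_name}"
    if any(size in model_name for size in ('13b', '20b', '30b')):
        return AutoModelForCausalLM.from_pretrained(
            path, device_map='auto', load_in_8bit=True)
    return AutoModelForCausalLM.from_pretrained(
        path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()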

torch-dumps/place-your-pt-models-here.txt (+0 -0)
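The second file in the commit is an empty placeholder that creates the torch-dumps/ directory checked by the fast-load branch at the top of load_model. A dump for that branch can be produced from an already-loaded model with plain torch.save (a sketch; model and model_name stand for whatever load_model was called with and returned):

import torch

# Pickle the whole module so the .pt branch of load_model()
# can torch.load() it directly instead of rebuilding from config.
torch.save(model, f"torch-dumps/{model_name}.pt")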