oobabooga пре 3 године
родитељ
комит
18ae08ef91
1 измењен фајл са 2 додато и 5 уклоњено
  1. convert-to-torch.py: +2 −5
      convert-to-torch.py

+ 2 - 5
convert-to-torch.py

@@ -7,7 +7,7 @@ python convert-to-torch.py models/opt-1.3b
 The output will be written to torch-dumps/name-of-the-model.pt
 '''
  
-from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
+from transformers import AutoModelForCausalLM
 import torch
 from sys import argv
 from pathlib import Path
@@ -16,10 +16,7 @@ path = Path(argv[1])
 model_name = path.name
 
 print(f"Loading {model_name}...")
-if model_name in ['flan-t5', 't5-large']:
-    model = T5ForConditionalGeneration.from_pretrained(path).cuda()
-else:
-    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
 print("Model loaded.")
 
 print(f"Saving to torch-dumps/{model_name}.pt")