
Clean the convert to torch script

oobabooga 3 years ago
commit 898e12058e
1 changed file with 15 additions and 26 deletions

convert-to-torch.py  +15 −26

@@ -1,38 +1,27 @@
 '''
 Converts a transformers model to .pt, which is faster to load.
  
-Run with python convert.py /path/to/model/
-Make sure to write /path/to/model/ with a trailing / and not
-/path/to/model
+Example:
+python convert.py models/opt-1.3b
  
 Output will be written to torch-dumps/name-of-the-model.pt
 '''
  
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, AutoTokenizer, set_seed
-from transformers import GPT2Tokenizer, GPT2Model, T5Tokenizer, T5ForConditionalGeneration
+from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
 import torch
-import sys
 from sys import argv
-import time
-import glob
-import psutil
  
-print(f"torch-dumps/{argv[1].split('/')[-2]}.pt")
- 
-if argv[1].endswith('pt'):
-    model = OPTForCausalLM.from_pretrained(argv[1], device_map="auto")
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
-elif 'galactica' in argv[1].lower():
-    model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
-    #model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, load_in_8bit=True)
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
-elif 'flan-t5' in argv[1].lower():
-    model = T5ForConditionalGeneration.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
+path = argv[1]
+if path[-1] != '/':
+    path = path+'/'
+model_name = path.split('/')[-2]
+
+print(f"Loading {model_name}...")
+if model_name in ['flan-t5', 't5-large']:
+    model = T5ForConditionalGeneration.from_pretrained(path).cuda()
 else:
-    print("Loading the model")
-    model = AutoModelForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
-    print("Model loaded")
-    #model = AutoModelForCausalLM.from_pretrained(argv[1], device_map='auto', load_in_8bit=True)
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
+    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+print("Model loaded.")
 
+print(f"Saving to torch-dumps/{model_name}.pt")
+torch.save(model, f"torch-dumps/{model_name}.pt")
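For context, the point of the dump is that torch.load() restores the pickled model object much faster than from_pretrained() can rebuild it. Below is a minimal sketch of the loading side, assuming the models/name-of-the-model/ folder layout used above; the tokenizer loading and the prompt are illustrative, not part of this commit:

# A minimal sketch of consuming the .pt file produced by the script above.
# Assumptions (not part of this commit): the tokenizer still lives in the
# original model folder, and the model was saved while on a CUDA device.
import torch
from transformers import AutoTokenizer

model_name = "opt-1.3b"  # illustrative name

# torch.save() above pickled the whole model object, so torch.load()
# restores it directly, skipping from_pretrained()'s weight loading.
# (On PyTorch >= 2.6 this needs weights_only=False, since the file is
# a full pickled object rather than a bare state dict.)
model = torch.load(f"torch-dumps/{model_name}.pt")
model.eval()

# The .pt dump contains no tokenizer; load it from the original folder.
tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
output_ids = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))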