@@ -0,0 +1,49 @@
+'''
+Converts a transformers model to .pt, which is faster to load.
+
+Run with: python convert.py /path/to/model/
+
+Make sure to write /path/to/model/ with a trailing / and not
+/path/to/model
+
+Output will be written to torch-dumps/name-of-the-model.pt
+'''
+
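+# Example with a hypothetical model folder:
+#   python convert.py models/opt-1.3b/  ->  torch-dumps/opt-1.3b.pt
+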
+import os
+
+import torch
+from sys import argv
+from transformers import AutoModelForCausalLM, OPTForCausalLM, T5ForConditionalGeneration
+
+# With the documented trailing /, the model folder name is the second-to-last path component
+model_name = argv[1].split('/')[-2]
+output_path = f"torch-dumps/{model_name}.pt"
+print(output_path)
+
+# torch.save() cannot create the output directory on its own
+os.makedirs('torch-dumps', exist_ok=True)
+
+# Folders whose name ends in "pt" are treated as OPT checkpoints and sharded
+# across the available devices (device_map="auto" requires the accelerate package)
+if model_name.endswith('pt'):
+    model = OPTForCausalLM.from_pretrained(argv[1], device_map="auto")
+elif 'galactica' in model_name.lower():
+    model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
+    #model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, load_in_8bit=True)
+elif 'flan-t5' in model_name.lower():
+    model = T5ForConditionalGeneration.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
+else:
+    print("Loading the model")
+    model = AutoModelForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
+    print("Model loaded")
+    #model = AutoModelForCausalLM.from_pretrained(argv[1], device_map='auto', load_in_8bit=True)
+
+torch.save(model, output_path)
+
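+# A sketch of consuming the dump later: a single torch.load() call replaces
+# rebuilding the model from its config and weight shards
+#
+#   model = torch.load("torch-dumps/name-of-the-model.pt")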