'''
Converts a transformers model to .pt, which is faster to load.

Example:
python convert.py models/opt-1.3b

Output will be written to torch-dumps/name-of-the-model.pt
'''

from pathlib import Path
from sys import argv

import torch
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration

# Fail with a usage message instead of a bare IndexError when the
# model path is missing.
if len(argv) < 2:
    raise SystemExit("Usage: python convert.py <path-to-model-directory>")

# Path(...).name yields the final directory component regardless of a
# trailing slash, replacing the manual slash-append + split('/')[-2].
path = Path(argv[1])
model_name = path.name

print(f"Loading {model_name}...")
if model_name in ['flan-t5', 't5-large']:
    # T5-family checkpoints need the seq2seq class; AutoModelForCausalLM
    # would reject them.
    model = T5ForConditionalGeneration.from_pretrained(path).cuda()
else:
    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
print("Model loaded.")

# torch.save does not create parent directories — make sure the output
# directory exists so a fresh checkout doesn't fail after the slow load.
out_dir = Path("torch-dumps")
out_dir.mkdir(exist_ok=True)
out_file = out_dir / f"{model_name}.pt"
print(f"Saving to {out_file}")
torch.save(model, out_file)