# convert-to-torch.py
  1. '''
  2. Converts a transformers model to .pt, which is faster to load.
  3. Example:
  4. python convert.py models/opt-1.3b
  5. Output will be written to torch-dumps/name-of-the-model.pt
  6. '''
  7. from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
  8. import torch
  9. from sys import argv
  10. path = argv[1]
  11. if path[-1] != '/':
  12. path = path+'/'
  13. model_name = path.split('/')[-2]
  14. print(f"Loading {model_name}...")
  15. if model_name in ['flan-t5', 't5-large']:
  16. model = T5ForConditionalGeneration.from_pretrained(path).cuda()
  17. else:
  18. model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
  19. print("Model loaded.")
  20. print(f"Saving to torch-dumps/{model_name}.pt")
  21. torch.save(model, f"torch-dumps/{model_name}.pt")