convert-to-torch.py 1.6 KB

'''
Converts a transformers model to .pt, which is faster to load.
Run with: python convert-to-torch.py /path/to/model/
Make sure to write /path/to/model/ with a trailing / and not
/path/to/model
Output will be written to torch-dumps/name-of-the-model.pt
'''
from sys import argv

import torch
from transformers import AutoModelForCausalLM, OPTForCausalLM, T5ForConditionalGeneration
# Note: the torch-dumps/ directory must already exist.
print(f"torch-dumps/{argv[1].split('/')[-2]}.pt")

if argv[1].endswith('pt'):
    # Paths ending in 'pt' are loaded with accelerate's device_map="auto".
    model = OPTForCausalLM.from_pretrained(argv[1], device_map="auto")
    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
elif 'galactica' in argv[1].lower():
    # Galactica checkpoints load through OPTForCausalLM.
    model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
    #model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, load_in_8bit=True)
    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
elif 'flan-t5' in argv[1].lower():
    # Flan-T5 is a seq2seq model, so it needs T5ForConditionalGeneration.
    model = T5ForConditionalGeneration.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
else:
    # Everything else goes through AutoModelForCausalLM in float16.
    print("Loading the model")
    model = AutoModelForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
    print("Model loaded")
    #model = AutoModelForCausalLM.from_pretrained(argv[1], device_map='auto', load_in_8bit=True)
    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
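For reference, a dump produced by this script can be read back with torch.load. This is a minimal sketch, not part of the script above: the model name opt-1.3b is only a placeholder, and on recent PyTorch releases (2.6+) loading a fully pickled model additionally requires passing weights_only=False.

import torch

# Load the pickled model saved by convert-to-torch.py (placeholder name).
model = torch.load("torch-dumps/opt-1.3b.pt")
model.eval()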