# convert-to-torch.py
  1. '''
  2. Converts a transformers model to .pt, which is faster to load.
  3. Example:
  4. python convert-to-torch.py models/opt-1.3b
  5. The output will be written to torch-dumps/name-of-the-model.pt
  6. '''
  7. from transformers import AutoModelForCausalLM
  8. import torch
  9. from sys import argv
  10. from pathlib import Path
  11. path = Path(argv[1])
  12. model_name = path.name
  13. print(f"Loading {model_name}...")
  14. model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
  15. print("Model loaded.")
  16. print(f"Saving to torch-dumps/{model_name}.pt")
  17. torch.save(model, Path(f"torch-dumps/{model_name}.pt"))