| 1234567891011121314151617181920212223242526272829303132333435363738 |
- '''
- Converts a transformers model to .pt, which is faster to load.
-
- Run with python convert.py /path/to/model/
- Make sure to write /path/to/model/ with a trailing / and not
- /path/to/model
-
- Output will be written to torch-dumps/name-of-the-model.pt
- '''
-
- from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, AutoTokenizer, set_seed
- from transformers import GPT2Tokenizer, GPT2Model, T5Tokenizer, T5ForConditionalGeneration
- import torch
- import sys
- from sys import argv
- import time
- import glob
- import psutil
-
def dump_path_for(model_path):
    """Return the ``torch-dumps/<name>.pt`` output path for *model_path*.

    The model name is the last component of the path. A trailing ``/`` is
    accepted but not required, so ``/path/to/model/`` and ``/path/to/model``
    both map to ``torch-dumps/model.pt``.

    Raises:
        ValueError: if no model name can be derived from the path
            (e.g. an empty string or ``/``).
    """
    from pathlib import PurePath  # local import: only this helper needs it

    # PurePath drops trailing separators, so .name is the model directory
    # itself rather than whatever happens to sit second-to-last after a
    # naive split('/').
    name = PurePath(model_path).name
    if not name:
        raise ValueError(f"Cannot derive a model name from {model_path!r}")
    return f"torch-dumps/{name}.pt"


def load_model(model_path):
    """Load the model at *model_path* with the class matching its name.

    Selection is by substring of the path, checked in this order:
    a path ending in ``pt``, then ``galactica``, then ``flan-t5``; anything
    else goes through ``AutoModelForCausalLM``. Every branch except the
    first loads in float16 with ``low_cpu_mem_usage=True``.
    """
    # NOTE(review): this test runs on the raw argument, so a path written
    # with a trailing '/' never takes this branch - confirm the intent.
    if model_path.endswith('pt'):
        return OPTForCausalLM.from_pretrained(model_path, device_map="auto")

    half_precision = {"low_cpu_mem_usage": True, "torch_dtype": torch.float16}
    lowered = model_path.lower()
    if 'galactica' in lowered:
        return OPTForCausalLM.from_pretrained(model_path, **half_precision)
    if 'flan-t5' in lowered:
        return T5ForConditionalGeneration.from_pretrained(model_path, **half_precision)
    return AutoModelForCausalLM.from_pretrained(model_path, **half_precision)


def main(args=None):
    """Convert the model named on the command line to a ``.pt`` dump.

    Args:
        args: argument list without the program name; defaults to
            ``sys.argv[1:]``. Exactly one argument (the model path) is
            expected.

    Exits with a usage message when the argument count is wrong.
    """
    import os  # local import: only needed to create the output directory

    if args is None:
        args = sys.argv[1:]
    if len(args) != 1:
        sys.exit("Usage: python convert.py /path/to/model/")

    model_path = args[0]
    destination = dump_path_for(model_path)
    print(destination)

    print("Loading the model")
    model = load_model(model_path)
    print("Model loaded")

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    torch.save(model, destination)


if __name__ == "__main__":
    main()
|