'''
Converts a transformers model to a format compatible with flexgen.
'''
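
# Example invocation (the script filename here is an assumption):
#   python convert-to-flexgen.py models/opt-1.3b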

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()
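# Note: MODEL is declared optional (nargs='?'), but Path(args.MODEL) below raises if it is omitted.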


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
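
    # Back up the default initializers and replace them with no-ops; from_pretrained
    # overwrites the weights with checkpoint values anyway, so the default init is wasted work.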
    global torch_linear_init_backup
    global torch_layer_norm_init_backup

    torch_linear_init_backup = torch.nn.Linear.reset_parameters
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)

    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def restore_torch_init():
    """Rollback the change made by disable_torch_init."""
    import torch

    setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
    setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)


if __name__ == '__main__':
    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    # disable_torch_init()
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    # restore_torch_init()

    tokenizer = AutoTokenizer.from_pretrained(path)

    out_folder = Path(f"models/{model_name}-np")
    if not out_folder.exists():
        os.makedirs(out_folder)

    print(f"Saving the converted model to {out_folder}...")
    for name, param in tqdm(list(model.model.named_parameters())):
        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
        param_path = os.path.join(out_folder, name)
        with open(param_path, "wb") as f:
            np.save(f, param.cpu().detach().numpy())