convert-to-flexgen.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. '''
  2. Converts a transformers model to a format compatible with flexgen.
  3. '''
  4. import argparse
  5. import os
  6. from pathlib import Path
  7. import numpy as np
  8. import torch
  9. from tqdm import tqdm
  10. from transformers import AutoModelForCausalLM, AutoTokenizer
  11. parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
  12. parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
  13. args = parser.parse_args()
  14. def disable_torch_init():
  15. """
  16. Disable the redundant torch default initialization to accelerate model creation.
  17. """
  18. import torch
  19. global torch_linear_init_backup
  20. global torch_layer_norm_init_backup
  21. torch_linear_init_backup = torch.nn.Linear.reset_parameters
  22. setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
  23. torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
  24. setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
  25. def restore_torch_init():
  26. """Rollback the change made by disable_torch_init."""
  27. import torch
  28. setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
  29. setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
  30. if __name__ == '__main__':
  31. path = Path(args.MODEL)
  32. model_name = path.name
  33. print(f"Loading {model_name}...")
  34. # disable_torch_init()
  35. model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
  36. # restore_torch_init()
  37. tokenizer = AutoTokenizer.from_pretrained(path)
  38. out_folder = Path(f"models/{model_name}-np")
  39. if not Path(out_folder).exists():
  40. os.mkdir(out_folder)
  41. print(f"Saving the converted model to {out_folder}...")
  42. for name, param in tqdm(list(model.model.named_parameters())):
  43. name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
  44. param_path = os.path.join(out_folder, name)
  45. with open(param_path, "wb") as f:
  46. np.save(f, param.cpu().detach().numpy())