convert-to-flexgen.py

'''
Converts a transformers model to a format compatible with flexgen.
'''

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()
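
# Example invocation (a sketch; "models/opt-1.3b" is a hypothetical path --
# point this at any local transformers checkpoint directory):
#
#   python convert-to-flexgen.py models/opt-1.3b
#
# The converted copy is written to models/opt-1.3b-np/.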


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    global torch_linear_init_backup
    global torch_layer_norm_init_backup

    # Back up the default initializers, then replace them with no-ops so that
    # building the model skips weight initialization that the checkpoint
    # overwrites anyway.
    torch_linear_init_backup = torch.nn.Linear.reset_parameters
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def restore_torch_init():
    """Rollback the change made by disable_torch_init."""
    import torch

    setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
    setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)


if __name__ == '__main__':
    if args.MODEL is None:
        parser.error("Please specify the path to the input model.")

    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    # disable_torch_init()
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    # restore_torch_init()

    tokenizer = AutoTokenizer.from_pretrained(path)

    out_folder = Path(f"models/{model_name}-np")
    out_folder.mkdir(parents=True, exist_ok=True)

    print(f"Saving the converted model to {out_folder}...")
    for name, param in tqdm(list(model.model.named_parameters())):
        # flexgen expects the final layer norm under the name "decoder.layer_norm".
        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")

        # np.save() is given an open file handle rather than a path so that the
        # filename stays exactly the parameter name, without an appended ".npy".
        param_path = os.path.join(out_folder, name)
        with open(param_path, "wb") as f:
            np.save(f, param.cpu().detach().numpy())
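
# A minimal sketch of verifying the output (assumes a conversion has been run;
# the path and parameter name below are illustrative, taken from an OPT model):
#
#   >>> import numpy as np
#   >>> with open("models/opt-1.3b-np/decoder.embed_tokens.weight", "rb") as f:
#   ...     w = np.load(f)
#   >>> w.dtype
#   dtype('float16')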