convert-to-flexgen.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. '''
  2. Converts a transformers model to a format compatible with flexgen.
  3. '''
  4. import argparse
  5. import os
  6. from pathlib import Path
  7. import numpy as np
  8. import torch
  9. from tqdm import tqdm
  10. from transformers import AutoModelForCausalLM
  11. from transformers import AutoTokenizer
  def _wide_help_formatter(prog):
      """Build the help formatter with a wider option column (max_help_position=54)."""
      return argparse.HelpFormatter(prog, max_help_position=54)


  # Command-line interface: a single optional positional argument, MODEL.
  parser = argparse.ArgumentParser(formatter_class=_wide_help_formatter)
  parser.add_argument(
      'MODEL',
      type=str,
      default=None,
      nargs='?',
      help="Path to the input model.",
  )
  # Parsed at module level; the script body reads the result from ``args``.
  args = parser.parse_args()
  def _skip_reset_parameters(self):
      """Stand-in for ``reset_parameters`` that leaves the module untouched."""
      return None


  def disable_torch_init():
      """
      Disable the redundant torch default initialization to accelerate model creation.

      Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
      ``torch.nn.LayerNorm`` with a no-op and stores the previous methods in the
      module-level globals ``torch_linear_init_backup`` and
      ``torch_layer_norm_init_backup``.

      Calling this several times in a row is safe: a class that is already
      patched is skipped, so a saved backup is never overwritten by the no-op.
      """
      import torch

      global torch_linear_init_backup
      global torch_layer_norm_init_backup

      # Only back up and patch when the class still carries a real initializer;
      # otherwise the "backup" would be the no-op and the original would be lost.
      if torch.nn.Linear.reset_parameters is not _skip_reset_parameters:
          torch_linear_init_backup = torch.nn.Linear.reset_parameters
          torch.nn.Linear.reset_parameters = _skip_reset_parameters

      if torch.nn.LayerNorm.reset_parameters is not _skip_reset_parameters:
          torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
          torch.nn.LayerNorm.reset_parameters = _skip_reset_parameters
  def restore_torch_init():
      """Undo disable_torch_init by reinstating the saved ``reset_parameters`` methods."""
      import torch

      # Values come from the module-level globals written by disable_torch_init.
      linear_cls = torch.nn.Linear
      linear_cls.reset_parameters = torch_linear_init_backup

      layer_norm_cls = torch.nn.LayerNorm
      layer_norm_cls.reset_parameters = torch_layer_norm_init_backup
  if __name__ == '__main__':
      # MODEL is declared optional on the command line (nargs='?', default None),
      # so report a missing value as a usage error instead of crashing in Path(None).
      if args.MODEL is None:
          parser.error("the MODEL argument is required")

      path = Path(args.MODEL)
      model_name = path.name

      print(f"Loading {model_name}...")
      # NOTE: the disable_torch_init()/restore_torch_init() helpers are not applied
      # around this load.
      model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

      out_folder = Path(f"models/{model_name}-np")
      # parents=True also creates "models/" when it is missing; exist_ok=True makes
      # re-running the conversion into an existing folder a no-op here.
      out_folder.mkdir(parents=True, exist_ok=True)

      print(f"Saving the converted model to {out_folder}...")
      # Materialize the parameters into a list so tqdm knows the total count.
      for name, param in tqdm(list(model.model.named_parameters())):
          # Rename the final layer norm parameters to "decoder.layer_norm".
          name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
          param_path = out_folder / name
          # Save through an open file handle so the file is named exactly `name`;
          # given a path, np.save would append a ".npy" extension.
          with open(param_path, "wb") as f:
              np.save(f, param.cpu().detach().numpy())