# LLaMA_8bit.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
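
"""8-bit LLaMA loader.

Reads Meta's sharded LLaMA checkpoints from a directory, merges the shards
into a single Transformer from repositories.llama_int8, quantizes it to
8 bits, and exposes it through the LLaMAModel_8bit wrapper below.
"""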
from typing import Tuple
import os
import sys
import time
import json
from pathlib import Path

import torch
import fire
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from repositories.llama_int8.llama import ModelArgs, Transformer, Tokenizer, LLaMA


def setup_model_parallel() -> Tuple[int, int]:
    # LOCAL_RANK and WORLD_SIZE are provided by the distributed launcher (e.g. torchrun).
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # Build the model in half precision so its parameters are allocated as fp16.
    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Creating transformer")
    model = Transformer(model_args)
    print("Transformer created")

    # Dimension along which each weight is split across the checkpoint shards
    # (None means the tensor is replicated and only needs to be read once).
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }

    # Reset the default tensor type back to CPU fp32.
    torch.set_default_tensor_type(torch.FloatTensor)

    # Load the state dict incrementally, to avoid memory problems: each shard is
    # read on the CPU, copied into the matching slice of the full parameter, and
    # freed before the next shard is loaded.
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif key_to_dim[short_name] == -1:
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
                    parameter_name
                ]
        del checkpoint
    # model.load_state_dict(checkpoint, strict=False)

    model.quantize()
    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


class LLaMAModel_8bit:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path, max_seq_len=2048, max_batch_size=1):
        # `path` is expected to be a pathlib.Path pointing at the model directory.
        tokenizer_path = path / "tokenizer.model"
        path = os.path.abspath(path)
        tokenizer_path = os.path.abspath(tokenizer_path)

        generator = load(path, tokenizer_path, max_seq_len, max_batch_size)

        result = cls()
        result.pipeline = generator
        return result

    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
        results = self.pipeline.generate(
            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
        )
        return results[0]
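

# Example usage (a minimal sketch; "models/LLaMA-7B" is a hypothetical path that
# must contain the sharded *.pth checkpoints, params.json, and tokenizer.model
# that load() expects).
if __name__ == "__main__":
    llama = LLaMAModel_8bit.from_pretrained(Path("models/LLaMA-7B"))
    print(llama.generate("The meaning of life is", token_count=64))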