LLaMA.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import json
import os
import sys
import time
from pathlib import Path
from typing import Tuple

import fire
import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import LLaMA, ModelArgs, Tokenizer, Transformer

# Pin the process-group environment to a single local process so the model
# can be loaded without launching through torchrun.
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MP'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '2223'
def setup_model_parallel() -> Tuple[int, int]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("gloo")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # Build the transformer with half-precision CUDA tensors as the default,
    # then restore the CPU float32 default for subsequent allocations.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
class LLaMAModel:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path, max_seq_len=512, max_batch_size=32):
        # Accept either a str or a Path; the tokenizer is expected to sit
        # inside the checkpoint directory.
        tokenizer_path = Path(path) / "tokenizer.model"
        path = os.path.abspath(path)
        tokenizer_path = os.path.abspath(tokenizer_path)

        local_rank, world_size = setup_model_parallel()
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        generator = load(
            path, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
        )

        result = cls()
        result.pipeline = generator
        return result

    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
        results = self.pipeline.generate(
            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
        )
        return results[0]
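

# Minimal usage sketch: the checkpoint directory below is a placeholder;
# point it at a folder holding a single consolidated *.pth shard,
# params.json, and tokenizer.model.
if __name__ == "__main__":
    llama = LLaMAModel.from_pretrained(
        "models/7B",  # hypothetical path to a one-shard (MP=1) checkpoint
        max_seq_len=512,
        max_batch_size=32,
    )
    print(llama.generate("The capital of France is", token_count=32))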