# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
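
"""Loader for Meta's LLaMA checkpoints with int8 quantization.

The Transformer is first built with half-precision weights, the sharded *.pth
checkpoints are streamed in one at a time to keep peak memory down, and the
assembled weights are then quantized to int8 via model.quantize().
"""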

import json
import os
import sys
import time
from pathlib import Path
from typing import Tuple

import fire
import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from repositories.llama_int8.llama import ModelArgs, Transformer, Tokenizer, LLaMA


def setup_model_parallel() -> Tuple[int, int]:
    """Initialize torch.distributed (NCCL) and fairscale model parallelism."""
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build the Transformer, load the sharded checkpoints and quantize to int8."""
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # Build the model with half-precision (FP16) weights.
    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Creating transformer")
    model = Transformer(model_args)
    print("Transformer created")
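    # Maps a parameter's short name (e.g. the "wq" in
    # "layers.0.attention.wq.weight") to the dimension along which its
    # model-parallel checkpoint shards are concatenated: 0 and -1 select the
    # first and last weight dimension respectively, while None marks weights
    # that are replicated in every shard and therefore loaded from shard 0 only.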
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }
    # Switch the default tensor type back to FP32 for anything created after
    # this point.
    torch.set_default_tensor_type(torch.FloatTensor)

    # Load the state dict incrementally, to avoid memory problems.
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                # Replicated weight: copy it from the first shard only.
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                # Shards are stacked along the first dimension.
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[parameter_name]
            elif key_to_dim[short_name] == -1:
                # Shards are stacked along the last dimension.
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[parameter_name]
        del checkpoint

    # model.load_state_dict(checkpoint, strict=False)
    model.quantize()

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


class LLaMAModel_8bit:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path, max_seq_len=2048, max_batch_size=1):
        tokenizer_path = path / "tokenizer.model"
        path = os.path.abspath(path)
        tokenizer_path = os.path.abspath(tokenizer_path)

        generator = load(path, tokenizer_path, max_seq_len, max_batch_size)

        result = cls()
        result.pipeline = generator
        return result

    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
        results = self.pipeline.generate(
            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
        )
        return results[0]
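

# Minimal usage sketch, under a few assumptions: "models/llama-7b" is a
# hypothetical directory containing the *.pth checkpoint shards, params.json
# and tokenizer.model, and the script is launched through torchrun so that
# LOCAL_RANK / WORLD_SIZE are available for setup_model_parallel().
if __name__ == "__main__":
    setup_model_parallel()
    llama = LLaMAModel_8bit.from_pretrained(Path("models/llama-7b"))
    print(llama.generate("The capital of France is", token_count=32))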