
Add LLaMA 8-bit support

oobabooga 2 years ago
parent
commit
bd8aac8fa4
2 changed files with 137 additions and 4 deletions
  1. modules/LLaMA_8bit.py  +125 −0
  2. modules/models.py  +12 −4

modules/LLaMA_8bit.py  +125 −0

@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Tuple
+import os
+import torch
+import time
+import json
+
+from pathlib import Path
+
+from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+from repositories.llama_int8.llama import ModelArgs, Transformer, Tokenizer, LLaMA
+
+
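+# Initialize torch.distributed (NCCL) and fairscale model parallelism. Relies
+# on the LOCAL_RANK and WORLD_SIZE environment variables that torchrun /
+# torch.distributed.launch provide. Note: load() below does not call this helper.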
+def setup_model_parallel() -> Tuple[int, int]:
+    local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+    torch.distributed.init_process_group("nccl")
+    initialize_model_parallel(world_size)
+    torch.cuda.set_device(local_rank)
+
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    return local_rank, world_size
+
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
+    start_time = time.time()
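+    # Meta's LLaMA checkpoints ship one *.pth shard per model-parallel rank
+    # (e.g. one for 7B, two for 13B)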
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.loads(f.read())
+
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    model_args.vocab_size = tokenizer.n_words
+    # Build the model in fp16 on the CPU (torch.cuda.HalfTensor would place it
+    # directly on the GPU instead)
+    torch.set_default_tensor_type(torch.HalfTensor)
+    print("Creating transformer")
+    model = Transformer(model_args)
+    print("Transformer created")
+
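+    # Dimension along which each (short) parameter name is sharded across the
+    # checkpoint files: 0 = concatenate rows, -1 = concatenate columns,
+    # None = not sharded (copied once, from the first checkpoint)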
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    # Restore the default tensor type to fp32 before reading the checkpoint shards
+    torch.set_default_tensor_type(torch.FloatTensor)
+
+    # load the state dict incrementally, to avoid memory problems
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i : size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+        del checkpoint
+
+    # Quantize the assembled fp16 weights to int8 (provided by the llama_int8 Transformer)
+    model.quantize()
+
+    generator = LLaMA(model, tokenizer)
+    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    return generator
+
+
+class LLaMAModel_8bit:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(cls, path, max_seq_len=2048, max_batch_size=1):
+        tokenizer_path = path / "tokenizer.model"
+        path = os.path.abspath(path)
+        tokenizer_path = os.path.abspath(tokenizer_path)
+
+        generator = load(path, tokenizer_path, max_seq_len, max_batch_size)
+
+        result = cls()
+        result.pipeline = generator
+        return result
+
+    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
+        results = self.pipeline.generate(
+            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
+        )
+        return results[0]
+

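A minimal sketch of how the new loader could be exercised on its own, assuming the llama_int8 Transformer runs without a torch.distributed process group (from_pretrained() never calls setup_model_parallel(); if the group is required, the script would have to be launched via torchrun with that helper invoked first). The model path and prompt are illustrative only:

# Hypothetical smoke test for modules/LLaMA_8bit.py; not part of this commit.
from pathlib import Path

from modules.LLaMA_8bit import LLaMAModel_8bit

# "models/llama-7b" is an assumed layout: the directory must contain
# params.json, the *.pth checkpoint shard(s), and tokenizer.model.
model = LLaMAModel_8bit.from_pretrained(Path("models/llama-7b"))

# generate() wraps LLaMA.generate() and returns the single completion.
print(model.generate("Hello, my name is", token_count=32))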
modules/models.py  +12 −4

@@ -88,12 +88,20 @@ def load_model(model_name):
 
     # LLaMA model (not on HuggingFace)
     elif shared.is_LLaMA:
-        import modules.LLaMA
-        from modules.LLaMA import LLaMAModel
+        if shared.args.load_in_8bit:
+            import modules.LLaMA_8bit
+            from modules.LLaMA_8bit import LLaMAModel_8bit
 
-        model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+            model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
 
-        return model, None
+            return model, None
+        else:
+            import modules.LLaMA
+            from modules.LLaMA import LLaMAModel
+
+            model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+
+            return model, None
 
     # Custom
     else:
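
Since the new branch keys off shared.args.load_in_8bit, the 8-bit LLaMA path should be reachable through the webui's existing --load-in-8bit flag, presumably something like (model directory name illustrative):

python server.py --model llama-7b --load-in-8bit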