
Add FlexGen support #92 (experimental)

oobabooga, 2 years ago
commit b83f51ee04
2 changed files with 123 additions and 20 deletions
  1. convert-to-flexgen.py (+63, -0)
  2. server.py (+60, -20)

convert-to-flexgen.py (+63, -0)

@@ -0,0 +1,63 @@
+'''
+
+Converts a transformers model to a format compatible with flexgen.
+
+'''
+
+import argparse
+import os
+import numpy as np
+from pathlib import Path
+from sys import argv
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
+ 
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
+args = parser.parse_args()
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    global torch_linear_init_backup
+    global torch_layer_norm_init_backup
+
+    torch_linear_init_backup = torch.nn.Linear.reset_parameters
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+
+    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def restore_torch_init():
+    """Rollback the change made by disable_torch_init."""
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
+    setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
+
+if __name__ == '__main__':
+    path = Path(args.MODEL)
+    model_name = path.name
+
+    print(f"Loading {model_name}...")
+    disable_torch_init()
+    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, _fast_init=True)
+    restore_torch_init()
+
+    tokenizer = AutoTokenizer.from_pretrained(path)
+
+    out_folder = Path(f"models/{model_name}-np")
+    if not Path(out_folder).exists():
+        os.mkdir(out_folder)
+
+    print(f"Saving the converted model to {out_folder}...")
+    for name, param in tqdm(list(model.model.named_parameters())):
+        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
+        param_path = os.path.join(out_folder, name)
+        with open(param_path, "wb") as f:
+            np.save(f, param.cpu().detach().numpy())

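Note: the converted folder ends up holding one raw NumPy file per parameter, saved under the original transformers parameter name (np.save is given an open file handle, so no .npy extension is appended). A minimal sketch of reading one shard back, assuming an opt-1.3b model was converted (the parameter name below is an assumption about OPT's layout, not something taken from this commit):

    import numpy as np

    # Each file under models/opt-1.3b-np holds a single parameter in .npy
    # format; np.load detects the format from the file contents, so the
    # missing .npy extension is not a problem.
    weight = np.load("models/opt-1.3b-np/decoder.embed_tokens.weight")
    print(weight.shape, weight.dtype)  # expected float16, per the conversion above
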
server.py (+60, -20)

@@ -45,6 +45,7 @@ parser.add_argument('--disk', action='store_true', help='If the model is too lar
 parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
 parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
 parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
+parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
 parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
 parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
 parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
@@ -86,6 +87,9 @@ if args.settings is not None and Path(args.settings).exists():
     for item in new_settings:
         settings[item] = new_settings[item]
 
+if args.flexgen:
+    from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, Task, get_opt_config)
+
 if args.deepspeed:
     import deepspeed
     from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
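Note: like the DeepSpeed block just below it, the FlexGen import is done conditionally, so flexgen stays an optional dependency that only needs to be installed (the FlexGen README suggests pip install flexgen) when --flexgen is actually passed.
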
@@ -107,12 +111,39 @@ def load_model(model_name):
     t0 = time.time()
 
     # Default settings
-    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
+    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed or args.flexgen):
         if any(size in model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
         else:
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16).cuda()
 
+    # FlexGen
+    elif args.flexgen:
+        gpu = TorchDevice("cuda:0")
+        cpu = TorchDevice("cpu")
+        disk = TorchDisk("cache")
+        env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
+
+        # Offloading policy
+        policy = Policy(1, 1,
+                        100, 0,
+                        100, 0,
+                        100, 0,
+                        overlap=True, sep_layer=True, pin_weight=True,
+                        cpu_cache_compute=False, attn_sparsity=1.0,
+                        compress_weight=False,
+                        comp_weight_config=CompressionConfig(
+                            num_bits=4, group_size=64,
+                            group_dim=0, symmetric=False),
+                        compress_cache=False,
+                        comp_cache_config=CompressionConfig(
+                            num_bits=4, group_size=64,
+                            group_dim=2, symmetric=False))
+
+        opt_config = get_opt_config(f"facebook/{model_name}")
+        model = OptLM(opt_config, env, "models", policy)
+        model.init_all_weights()
+
     # DeepSpeed ZeRO-3
     elif args.deepspeed:
         model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
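Note: the positional arguments to Policy above are hard to read on their own. Under the assumption that flexgen's flex_opt.Policy dataclass names its fields as shown below (the keyword names are my reading of that library, not part of this commit), the same policy spelled out would be roughly:

    from flexgen.flex_opt import CompressionConfig, Policy

    # Hedged sketch: the policy from the diff, with assumed keyword names.
    # Everything stays on the GPU and no compression is enabled.
    policy = Policy(
        gpu_batch_size=1, num_gpu_batches=1,
        w_gpu_percent=100, w_cpu_percent=0,          # weights: 100% GPU, 0% CPU
        cache_gpu_percent=100, cache_cpu_percent=0,  # KV cache: 100% GPU, 0% CPU
        act_gpu_percent=100, act_cpu_percent=0,      # activations: 100% GPU, 0% CPU
        overlap=True, sep_layer=True, pin_weight=True,
        cpu_cache_compute=False, attn_sparsity=1.0,
        compress_weight=False,
        comp_weight_config=CompressionConfig(num_bits=4, group_size=64,
                                             group_dim=0, symmetric=False),
        compress_cache=False,
        comp_cache_config=CompressionConfig(num_bits=4, group_size=64,
                                            group_dim=2, symmetric=False))
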
@@ -273,7 +304,7 @@ def get_max_prompt_length(tokens):
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
-    if args.cpu:
+    if args.cpu or args.flexgen:
         return input_ids
     elif args.deepspeed:
         return input_ids.to(device=local_rank)
@@ -315,7 +346,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         print(f"\n\n{question}\n--------------------\n")
 
     input_ids = encode(question, tokens)
-    cuda = "" if (args.cpu or args.deepspeed) else ".cuda()"
+    cuda = "" if (args.cpu or args.deepspeed or args.flexgen) else ".cuda()"
     n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
     if stopping_string is not None:
         # The stopping_criteria code below was copied from
@@ -330,22 +361,28 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     else:
         stopping_criteria_list = None
 
-    generate_params = [
-        f"eos_token_id={n}",
-        f"stopping_criteria=stopping_criteria_list",
-        f"do_sample={do_sample}",
-        f"temperature={temperature}",
-        f"top_p={top_p}",
-        f"typical_p={typical_p}",
-        f"repetition_penalty={repetition_penalty}",
-        f"top_k={top_k}",
-        f"min_length={min_length if args.no_stream else 0}",
-        f"no_repeat_ngram_size={no_repeat_ngram_size}",
-        f"num_beams={num_beams}",
-        f"penalty_alpha={penalty_alpha}",
-        f"length_penalty={length_penalty}",
-        f"early_stopping={early_stopping}",
-    ]
+    if not args.flexgen:
+        generate_params = [
+            f"eos_token_id={n}",
+            f"stopping_criteria=stopping_criteria_list",
+            f"do_sample={do_sample}",
+            f"temperature={temperature}",
+            f"top_p={top_p}",
+            f"typical_p={typical_p}",
+            f"repetition_penalty={repetition_penalty}",
+            f"top_k={top_k}",
+            f"min_length={min_length if args.no_stream else 0}",
+            f"no_repeat_ngram_size={no_repeat_ngram_size}",
+            f"num_beams={num_beams}",
+            f"penalty_alpha={penalty_alpha}",
+            f"length_penalty={length_penalty}",
+            f"early_stopping={early_stopping}",
+        ]
+    else:
+        generate_params = [
+            f"do_sample={do_sample}",
+            f"temperature={temperature}",
+        ]
 
     if args.deepspeed:
         generate_params.append("synced_gpus=True")
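Note: these parameter strings are later joined into the actual model.generate(...) call by code outside this hunk, which is presumably why the --flexgen branch passes only sampling controls: FlexGen's OptLM.generate supports far fewer options than transformers' generate(). A hedged illustration of the assembly (the joining code here is an assumption about the surrounding server.py, not part of this diff):

    # Assumed pattern: the f-string fragments are concatenated into a single
    # model.generate(...) call string by code outside this hunk.
    do_sample, temperature = True, 0.7
    generate_params = [f"do_sample={do_sample}", f"temperature={temperature}"]
    command = "model.generate(input_ids, " + ", ".join(generate_params) + ")"
    print(command)  # model.generate(input_ids, do_sample=True, temperature=0.7)
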
@@ -391,7 +428,10 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
                 reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
 
-            input_ids = torch.reshape(output, (1, output.shape[0]))
+            if not args.flexgen:
+                input_ids = torch.reshape(output, (1, output.shape[0]))
+            else:
+                input_ids = np.reshape(output, (1, output.shape[0]))
             if soft_prompt:
                 inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
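
Note: the split between torch.reshape and np.reshape above exists because the FlexGen path apparently hands back outputs as NumPy arrays rather than torch tensors (an inference from this branch, not stated elsewhere in the diff). A small self-contained illustration of the two branches:

    import numpy as np
    import torch

    # Stand-ins for a 1-D sequence of generated token ids from either backend.
    output_hf = torch.arange(5)   # transformers path returns torch tensors
    output_fg = np.arange(5)      # FlexGen path returns numpy arrays (assumed)

    # Both branches reshape the 1-D output into a (1, seq_len) batch.
    input_ids_hf = torch.reshape(output_hf, (1, output_hf.shape[0]))
    input_ids_fg = np.reshape(output_fg, (1, output_fg.shape[0]))
    print(input_ids_hf.shape, input_ids_fg.shape)  # torch.Size([1, 5]) (1, 5)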