
Simplify deepspeed implementation (#40)

oobabooga 3 years ago
commit 2583bc5840
1 changed file with 18 additions and 26 deletions

server.py  +18 -26

@@ -83,10 +83,7 @@ if args.deepspeed:
     from modules.deepspeed_parameters import generate_ds_config
 
     # Distributed setup
-    if args.local_rank is not None:
-        local_rank = args.local_rank
-    else:
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+    local_rank = args.local_rank if args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     torch.cuda.set_device(local_rank)
     deepspeed.init_distributed()
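
This first hunk only collapses the local_rank if/else into a conditional expression; the behaviour is unchanged. Below is a minimal, self-contained sketch of the same setup, assuming the script is started with the deepspeed launcher (which exports LOCAL_RANK and WORLD_SIZE); the setup_distributed helper is a hypothetical name used only for illustration.

    import os
    import torch
    import deepspeed

    def setup_distributed(local_rank_arg=None):
        # Mirror of the simplified logic above: prefer the explicit --local_rank
        # argument, otherwise fall back to the LOCAL_RANK environment variable.
        local_rank = local_rank_arg if local_rank_arg is not None else int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        torch.cuda.set_device(local_rank)   # bind this process to its GPU
        deepspeed.init_distributed()        # join the distributed process group
        return local_rank, world_size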
@@ -109,15 +106,8 @@ def load_model(model_name):
 
     # DeepSpeed ZeRO-3
     elif args.deepspeed:
-        if args.bf16:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16)
-        else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16)
-        model = deepspeed.initialize(model=model,
-                                     config_params=ds_config,
-                                     model_parameters=None,
-                                     optimizer=None,
-                                     lr_scheduler=None)[0]
+        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
+        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval() # Inference
         print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
 
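The ZeRO-3 branch now chooses the dtype inline and keeps only the engine returned by deepspeed.initialize. Here is a sketch of the same pattern in isolation, assuming transformers and deepspeed are installed and ds_config is a ZeRO-3 config dict as produced by generate_ds_config; the load_deepspeed_model name is illustrative, not part of the repository.

    from pathlib import Path

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    def load_deepspeed_model(model_name, ds_config, bf16=False):
        # Choose the half-precision dtype inline instead of duplicating the call.
        dtype = torch.bfloat16 if bf16 else torch.float16
        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=dtype)

        # deepspeed.initialize returns (engine, optimizer, dataloader, scheduler);
        # only the engine is needed for inference, and the wrapped module goes to eval mode.
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None,
                                     optimizer=None, lr_scheduler=None)[0]
        model.module.eval()
        return model
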
@@ -183,7 +173,11 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     else:
         torch.cuda.empty_cache()
         input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
-    return input_ids
+
+    if not args.deepspeed:
+        return input_ids
+    else:
+        return input_ids.to(device=local_rank)
 
 def decode(output_ids):
     reply = tokenizer.decode(output_ids, skip_special_tokens=True)
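
encode() now owns the device placement, so callers no longer branch on args.deepspeed. A tiny sketch of that return path on its own; to_generation_device is a hypothetical helper, and local_rank stands for the rank resolved during the distributed setup above.

    import torch

    def to_generation_device(input_ids, use_deepspeed, local_rank):
        # Under DeepSpeed the ids must live on this process's GPU (local_rank);
        # otherwise they are returned exactly as encode() produced them.
        return input_ids.to(device=local_rank) if use_deepspeed else input_ids

    # Example with a dummy batch of token ids (CPU path, no GPU required):
    ids = torch.tensor([[1, 2, 3]])
    print(to_generation_device(ids, use_deepspeed=False, local_rank=0).shape)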
@@ -226,10 +220,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
 
     cuda = "" if args.cpu else ".cuda()"
     n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-    if args.deepspeed:
-        input_ids = encode(question, tokens).to(device=local_rank)
-    else:
-        input_ids = encode(question, tokens)
+    input_ids = encode(question, tokens)
+
     if stopping_string is not None:
         # The stopping_criteria code below was copied from
         # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
@@ -246,11 +238,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
-        if args.deepspeed:
-            with torch.no_grad():
+        with torch.no_grad():
+            if not args.deepspeed:
+                output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
+            else:
                 output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
-        else:
-            output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
         reply = decode(output[0])
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
@@ -263,11 +255,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         yield formatted_outputs(original_question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
         for i in tqdm(range(tokens//8+1)):
-            if args.deepspeed:
-                with torch.no_grad():
+            with torch.no_grad():
+                if not args.deepspeed:
+                    output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
+                else:
                     output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
-            else:
-                output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
             if not (args.chat or args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
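
The streaming loop gets the same restructuring: one no_grad context with the DeepSpeed branch inside it, generating at most 8 new tokens per iteration for tokens//8+1 iterations. A quick check of that chunk arithmetic (the numbers are only an example):

    tokens = 200                 # requested new tokens
    chunk = 8                    # max_new_tokens per streamed step
    steps = tokens // chunk + 1  # number of loop iterations
    print(steps, steps * chunk)  # -> 26 208, enough to cover the 200 requested tokens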