
Remove "eval" statements from text generation functions
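
For context: the pattern removed here assembles a Python call as a string and executes it with eval(), while the replacement collects keyword arguments in a dict and unpacks them with **. A minimal, self-contained sketch of the two styles (FakeModel and the parameter values are stand-ins for this illustration, not the real shared.model):

    # Sketch only: FakeModel echoes its kwargs so the two styles can be compared.
    class FakeModel:
        def generate(self, **kwargs):
            return [kwargs]

    model = FakeModel()
    max_new_tokens, do_sample = 200, True

    # Before: keyword arguments assembled as strings, then eval'd
    # (opaque to linters, and any quoting mistake only fails at runtime).
    params = [f"max_new_tokens={max_new_tokens}", f"do_sample={do_sample}"]
    output = eval(f"model.generate({', '.join(params)})")[0]

    # After: a plain dict unpacked with ** (the same call, no string execution).
    params = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    assert model.generate(**params)[0] == output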

oobabooga committed 2 years ago
parent commit afc5339510
1 changed file with 34 additions and 31 deletions:
    modules/text_generation.py

modules/text_generation.py (+34, -31)

@@ -122,7 +122,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
     output = input_ids[0]
-    cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
+    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))
@@ -132,45 +132,48 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         t = encode(stopping_string, 0, add_special_tokens=False)
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
+    generate_params = {}
     if not shared.args.flexgen:
-        generate_params = [
-            f"max_new_tokens=max_new_tokens",
-            f"eos_token_id={eos_token_ids}",
-            f"stopping_criteria=stopping_criteria_list",
-            f"do_sample={do_sample}",
-            f"temperature={temperature}",
-            f"top_p={top_p}",
-            f"typical_p={typical_p}",
-            f"repetition_penalty={repetition_penalty}",
-            f"top_k={top_k}",
-            f"min_length={min_length if shared.args.no_stream else 0}",
-            f"no_repeat_ngram_size={no_repeat_ngram_size}",
-            f"num_beams={num_beams}",
-            f"penalty_alpha={penalty_alpha}",
-            f"length_penalty={length_penalty}",
-            f"early_stopping={early_stopping}",
-        ]
+        generate_params.update({
+            "max_new_tokens": max_new_tokens,
+            "eos_token_id": eos_token_ids,
+            "stopping_criteria": stopping_criteria_list,
+            "do_sample": do_sample,
+            "temperature": temperature,
+            "top_p": top_p,
+            "typical_p": typical_p,
+            "repetition_penalty": repetition_penalty,
+            "top_k": top_k,
+            "min_length": min_length if shared.args.no_stream else 0,
+            "no_repeat_ngram_size": no_repeat_ngram_size,
+            "num_beams": num_beams,
+            "penalty_alpha": penalty_alpha,
+            "length_penalty": length_penalty,
+            "early_stopping": early_stopping,
+        })
     else:
-        generate_params = [
-            f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
-            f"do_sample={do_sample}",
-            f"temperature={temperature}",
-            f"stop={eos_token_ids[-1]}",
-        ]
+        generate_params.update({
+            "max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
+            "do_sample": do_sample,
+            "temperature": temperature,
+            "stop": eos_token_ids[-1],
+        })
     if shared.args.deepspeed:
-        generate_params.append("synced_gpus=True")
+        generate_params.update({"synced_gpus": True})
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-        generate_params.insert(0, "inputs_embeds=inputs_embeds")
-        generate_params.insert(0, "inputs=filler_input_ids")
+        generate_params.update({"inputs_embeds": inputs_embeds})
+        generate_params.update({"inputs": filler_input_ids})
     else:
-        generate_params.insert(0, "inputs=input_ids")
+        generate_params.update({"inputs": input_ids})
 
     try:
         # Generate the entire reply at once.
         if shared.args.no_stream:
             with torch.no_grad():
-                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+                output = shared.model.generate(**generate_params)[0]
+                if cuda:
+                    output = output.cuda()
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
@@ -194,7 +197,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 return Iteratorize(generate_with_callback, kwargs, callback=None)
 
             yield formatted_outputs(original_question, shared.model_name)
-            with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+            with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
@@ -214,7 +217,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             for i in range(max_new_tokens//8+1):
                 clear_torch_cache()
                 with torch.no_grad():
-                    output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
+                    output = shared.model.generate(**generate_params)[0]
                 if shared.soft_prompt:
                     output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
                 reply = decode(output)
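
The ".cuda()" suffix, previously spliced into the eval'd string, likewise becomes a boolean flag plus an explicit tensor move after generation. A hedged sketch of that piece in isolation, with the shared.args flags stubbed as plain booleans and an availability guard added so the snippet also runs on CPU-only machines:

    import torch

    cpu, deepspeed, flexgen = False, False, False  # stand-ins for shared.args.*
    cuda = not any((cpu, deepspeed, flexgen))      # a boolean, not a ".cuda()" string

    output = torch.tensor([1, 2, 3])               # stand-in for model.generate(...)[0]
    if cuda and torch.cuda.is_available():         # guard added for this sketch only
        output = output.cuda()                     # explicit device move replaces the eval suffix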