فهرست منبع

Remove variables

oobabooga 2 سال پیش
والد
کامیت
849a54ef2d
1فایلهای تغییر یافته به همراه4 افزوده شده و 8 حذف شده
  1. 4 8
      modules/text_generation.py

+ 4 - 8
modules/text_generation.py

@@ -103,12 +103,8 @@ def stop_everything_event():
     shared.stop_everything = True
 
 def generate_reply(question, generate_params, eos_token=None, stopping_strings=[]):
-    max_new_tokens = generate_params['max_new_tokens']
-    seed = generate_params['seed']
-    print(generate_params)
-    print('---------------')
     clear_torch_cache()
-    set_manual_seed(seed)
+    set_manual_seed(generate_params['seed'])
     shared.stop_everything = False
     updated_params = {}
     t0 = time.time()
@@ -155,7 +151,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[
             print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
             return
 
-    input_ids = encode(question, max_new_tokens)
+    input_ids = encode(question, generate_params['max_new_tokens'])
     original_input_ids = input_ids
     output = input_ids[0]
 
@@ -168,7 +164,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[
         t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
-    updated_params["max_new_tokens"] = max_new_tokens
+    updated_params["max_new_tokens"] = generate_params['max_new_tokens']
     if not shared.args.flexgen:
         updated_params["eos_token_id"] = eos_token_ids
         updated_params["stopping_criteria"] = stopping_criteria_list
@@ -244,7 +240,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[
 
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         else:
-            for i in range(max_new_tokens//8+1):
+            for i in range(generate_params['max_new_tokens']//8+1):
                 clear_torch_cache()
                 with torch.no_grad():
                     output = shared.model.generate(**updated_params)[0]