浏览代码

Use **kwargs in generate_chat_prompt

oobabooga 2 年之前
父节点
当前提交
97e8ea219b
共有 2 个文件被更改,包括 10 次插入和 5 次删除
  1. 9 4
      modules/chat.py
  2. 1 1
      server.py

+ 9 - 4
modules/chat.py

@@ -18,7 +18,12 @@ from modules.text_generation import (encode, generate_reply,
                                      get_max_prompt_length)
 
 
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False):
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
+    is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
+    end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
+    impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
+
     user_input = fix_newlines(user_input)
     rows = [f"{context.strip()}\n"]
 
@@ -112,11 +117,11 @@ def chatbot_wrapper(text, max_new_tokens, generation_params, seed, name1, name2,
         visible_text = text
     text = apply_extensions(text, "input")
 
-    is_instruct = mode == 'instruct'
+    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
     if custom_generate_chat_prompt is None:
-        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
+        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs)
     else:
-        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
+        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs)
 
     # Yield *Is typing...*
     if not regenerate:

+ 1 - 1
server.py

@@ -474,7 +474,7 @@ def create_interface():
 
             shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'generation_state', 'seed']]
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
-            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")