
Add a Continue button to chat mode

oobabooga committed 2 years ago
commit d29f4624e9
2 changed files with 44 additions and 12 deletions:
  1. modules/chat.py (+37 −10)
  2. server.py (+7 −2)

modules/chat.py (+37 −10)

@@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
     end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    _continue = kwargs['_continue'] if '_continue' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
     rows = [f"{context.strip()}\n"]
 
@@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
 
     i = len(shared.history['internal']) - 1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
-        rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+        if _continue and i == len(shared.history['internal']) - 1:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
+        else:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
         string = shared.history['internal'][i][0]
         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
             rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
@@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     if impersonate:
         rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
         limit = 2
+    elif _continue:
+        limit = 3
     else:
         # Adding the user message
         user_input = fix_newlines(user_input)
@@ -56,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
 
         # Adding the Character prefix
         rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
+
         limit = 3
 
     while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
         rows.pop(1)
     prompt = ''.join(rows)
-
     if also_return_rows:
         return prompt, rows
     else:
@@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     return reply, next_character_found
 
 
-def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
     if mode == 'instruct':
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
     else:
@@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
     # Defining some variables
     cumulative_reply = ''
+    last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
     just_started = True
     name1_original = name1
     visible_text = custom_generate_chat_prompt = None
@@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
     if visible_text is None:
         visible_text = text
-    text = apply_extensions(text, "input")
+    if not _continue:
+        text = apply_extensions(text, "input")
 
     # Generating the prompt
-    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
+    kwargs = {
+        'end_of_turn': end_of_turn,
+        'is_instruct': mode == 'instruct',
+        '_continue': _continue
+    }
     if custom_generate_chat_prompt is None:
         prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
     else:
         prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
 
     # Yield *Is typing...*
-    if not regenerate:
+    if not any((regenerate, _continue)):
         yield shared.history['visible'] + [[visible_text, shared.processing_message]]
 
     # Generate
@@ -154,11 +166,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
                 return shared.history['visible']
             if just_started:
                 just_started = False
-                shared.history['internal'].append(['', ''])
-                shared.history['visible'].append(['', ''])
+                if not _continue:
+                    shared.history['internal'].append(['', ''])
+                    shared.history['visible'].append(['', ''])
 
-            shared.history['internal'][-1] = [text, reply]
-            shared.history['visible'][-1] = [visible_text, visible_reply]
+            if _continue:
+                shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}']
+                shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}']
+            else:
+                shared.history['internal'][-1] = [text, reply]
+                shared.history['visible'][-1] = [visible_text, visible_reply]
             if not shared.args.no_stream:
                 yield shared.history['visible']
             if next_character_found:
@@ -220,6 +237,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
+        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+    else:
+        # Yield ' ...'
+        yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
+        for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
+            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+
+
 def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
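
Note: the sketch below is not the repository's code; it is a minimal, self-contained illustration of the continuation logic this diff adds to modules/chat.py. When a `_continue` flag is set, the last bot reply is put back into the prompt without the end-of-turn marker so the model resumes mid-message, and the newly generated text is appended to that reply instead of starting a new history row. All names here (`build_prompt`, `continue_last_reply`, `history`, `generate`) are placeholders, not the project's API.

    # Illustrative sketch only; names are placeholders, not the repo's API.
    def build_prompt(history, user_name, bot_name, end_of_turn="\n", _continue=False):
        rows = []
        for i, (user_msg, bot_msg) in enumerate(history):
            rows.append(f"{user_name}: {user_msg.strip()}{end_of_turn}")
            if _continue and i == len(history) - 1:
                # Re-insert the last reply WITHOUT end_of_turn so the model keeps writing it.
                rows.append(f"{bot_name}: {bot_msg}")
            else:
                rows.append(f"{bot_name}: {bot_msg.strip()}{end_of_turn}")
        if not _continue:
            rows.append(f"{bot_name}:")  # fresh bot prefix for a normal turn
        return "".join(rows)

    def continue_last_reply(history, generate):
        # Append the new tokens to the existing last reply instead of adding a row,
        # mirroring the diff's `[text, f'{last_reply[0]} {reply}']` update.
        prompt = build_prompt(history, "You", "Bot", _continue=True)
        new_text = generate(prompt)  # stand-in for the model call
        user_msg, old_reply = history[-1]
        history[-1] = [user_msg, f"{old_reply} {new_text}"]
        return history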

server.py (+7 −2)

@@ -327,8 +327,9 @@ def create_interface():
                     shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
                     shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
                 with gr.Row():
-                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
                     shared.gradio['Regenerate'] = gr.Button('Regenerate')
+                    shared.gradio['Continue'] = gr.Button('Continue')
+                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
                 with gr.Row():
                     shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
                     shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
@@ -411,7 +412,11 @@ def create_interface():
 
             gen_events.append(shared.gradio['Regenerate'].click(
                 chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
+            )
+
+            gen_events.append(shared.gradio['Continue'].click(
+                chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
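
Usage note: the server.py wiring follows the same pattern as the existing Regenerate button, so a Continue click streams updated chat HTML into the display and then saves the history. The sketch below is a hedged reduction of that pattern with placeholder callbacks (`continue_reply`, `save_history`) standing in for chat.continue_wrapper and chat.save_history; it is not the project's actual interface code.

    import gradio as gr

    def continue_reply():
        # Placeholder generator: in the real code this is chat.continue_wrapper,
        # which streams updated chat HTML as the last reply is extended.
        yield "<p>Bot: ...previous reply, now being extended...</p>"

    def save_history():
        # Placeholder for chat.save_history(timestamp=False).
        pass

    with gr.Blocks() as demo:
        display = gr.HTML()
        with gr.Row():
            regenerate = gr.Button("Regenerate")
            cont = gr.Button("Continue")
            impersonate = gr.Button("Impersonate")

        # Mirrors the diff: stream output into the display, then persist the history.
        cont.click(continue_reply, None, display).then(save_history, None, None)

    if __name__ == "__main__":
        demo.launch()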