
Implement text streaming (#10)

Still experimental. There might be bugs.
oobabooga 3 years ago
parent
commit
0f01a3b1fa
1 changed file with 72 additions and 51 deletions
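
For context, the streaming below relies on Gradio's support for generator event handlers: when a handler yields instead of returning, each yielded value refreshes the output components, provided the queue is enabled via interface.queue(). A minimal, self-contained sketch of that pattern (toy handler and component names, not the project's code):

    import gradio as gr

    # Toy generator handler: each yield pushes the text so far to the textbox.
    def stream_numbers(count):
        text = ""
        for i in range(int(count)):
            text += f"{i} "
            yield text

    with gr.Blocks() as demo:
        count = gr.Number(value=10, label="How many numbers")
        out = gr.Textbox()
        btn = gr.Button("Generate")
        btn.click(stream_numbers, count, out)

    demo.queue()   # generator handlers only stream when the queue is enabled
    demo.launch()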

+ 72 - 51
server.py

@@ -139,25 +139,28 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
         loaded_preset = inference_settings
 
-    input_ids = encode(question, tokens)
+    for i in range(tokens):
+        input_ids = encode(question, 1)
+        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
 
-    cuda = "" if args.cpu else ".cuda()"
-    if eos_token is None:
-        output = eval(f"model.generate(input_ids, {preset}){cuda}")
-    else:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
-
-    reply = tokenizer.decode(output[0], skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
-    else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+        cuda = "" if args.cpu else ".cuda()"
+        if eos_token is None:
+            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+        else:
+            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+
+        reply = tokenizer.decode(output[0], skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        question = reply
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            yield reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
 
 # Choosing the default model
 if args.model is not None:
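
The rewritten generate_reply above streams by brute force: on every iteration it re-encodes the text produced so far, asks the model for a single new token (max_new_tokens=1), decodes the whole sequence, and yields it. A stripped-down sketch of the same idea using plain transformers calls and a placeholder model name, rather than the project's encode()/preset machinery:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    def stream_generate(prompt, max_new_tokens=50):
        text = prompt
        for _ in range(max_new_tokens):
            input_ids = tokenizer(text, return_tensors="pt").input_ids
            output = model.generate(input_ids, max_new_tokens=1, do_sample=True)
            # Decode the full sequence and feed it back in as the next prompt.
            text = tokenizer.decode(output[0], skip_special_tokens=True)
            yield text

    for partial in stream_generate("The quick brown fox"):
        print(partial)

Re-encoding the full text and calling generate() once per token keeps the change small, but it recomputes the whole context on every step, so it is slower than a single call that produces all tokens at once.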
@@ -205,20 +208,20 @@ if args.notebook:
             with gr.Column():
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
 elif args.chat or args.cai_chat:
     history = []
 
     # This gets the new line characters right.
-    def chat_response_cleaner(text):
+    def clean_chat_message(text):
         text = text.replace('\n', '\n\n')
         text = re.sub(r"\n{3,}", "\n\n", text)
         text = text.strip()
         return text
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        text = chat_response_cleaner(text)
+    def generate_chat_prompt(text, tokens, name1, name2, context):
+        text = clean_chat_message(text)
 
         rows = [f"{context}\n\n"]
         i = len(history)-1
@@ -234,26 +237,42 @@ elif args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        return question
 
-        if check:
-            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
-        else:
-            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):]
-            idx = reply.find(f"\n{name1}:")
-            if idx != -1:
-                reply = reply[:idx]
-            reply = chat_response_cleaner(reply)
-
-        history.append((text, reply))
-        return history
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+        history.append(['', ''])
+        question = generate_chat_prompt(text, tokens, name1, name2, context)
+        eos_token = '\n' if check else None
+        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+            reply = i[0]
+
+            if check:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+            else:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):]
+                idx = reply.find(f"\n{name1}:")
+                if idx != -1:
+                    reply = reply[:idx]
+                reply = clean_chat_message(reply)
+
+            history[-1] = [text, reply]
+
+            # Prevent the chat log from flashing if something like "\nYo" is generated just
+            # before "\nYou:" is completed
+            tmp = f"\n{name1}:"
+            found = False
+            for j in range(1, len(tmp)):
+                if reply[-j:] == tmp[:j]:
+                    found = True
+
+            if not found:
+                yield history
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check)
-        return generate_chat_html(history, name1, name2)
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+            yield generate_chat_html(history, name1, name2)
 
     def remove_last_message(name1, name2):
         history.pop()
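
The inner loop over tmp above suppresses a yield whenever the reply currently ends with a partial prefix of the stop string f"\n{name1}:", so half-written stop markers such as "\nYo" never flash in the chat log. The check boils down to the following (hypothetical helper name, not part of the diff):

    def ends_with_prefix_of(reply, stop):
        # True if the tail of `reply` matches some leading part of `stop`,
        # e.g. a reply ending in "\nYo" against the stop string "\nYou:".
        return any(reply.endswith(stop[:j]) for j in range(1, len(stop)))

    assert ends_with_prefix_of("Hi there\nYo", "\nYou:")
    assert not ends_with_prefix_of("Hi there.", "\nYou:")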
@@ -305,13 +324,13 @@ elif args.chat or args.cai_chat:
             check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
 
         if args.cai_chat:
-            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(clear_html, [], display1, show_progress=False)
+            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(clear_html, [], display1, show_progress=True)
         else:
-            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(lambda x: "", display1, display1)
+            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(lambda x: "", display1, display1, show_progress=True)
 
         btn2.click(clear)
         btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
@@ -320,8 +339,9 @@ elif args.chat or args.cai_chat:
 else:
 
     def continue_wrapper(question, tokens, inference_settings, selected_model):
-        a, b, c = generate_reply(question, tokens, inference_settings, selected_model)
-        return a, a, b, c
+        for i in generate_reply(question, tokens, inference_settings, selected_model):
+            a, b, c = i
+            yield a, a, b, c
 
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
         gr.Markdown(description)
@@ -341,10 +361,11 @@ else:
                 with gr.Tab('HTML'):
                     html = gr.HTML()
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True, api_name="textgen")
-        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=True)
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
+        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False)
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
 
+interface.queue()
 if args.no_listen:
     interface.launch(share=False)
 else: