Jelajahi Sumber

Pygmalion: add checkbox for choosing whether to stop at newline or not

oobabooga 3 tahun lalu
induk
melakukan
ecb2cc2194
1 mengubah file dengan 16 tambahan dan 5 penghapusan
  1. 16 5
      server.py

+ 16 - 5
server.py

@@ -188,7 +188,7 @@ if args.notebook:
 elif args.chat:
     history = []
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context):
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         question = context+'\n\n'
         for i in range(len(history)):
             question += f"{name1}: {history[i][0][3:-5].strip()}\n"
@@ -196,8 +196,16 @@ elif args.chat:
         question += f"{name1}: {text.strip()}\n"
         question += f"{name2}:"
 
-        reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-        reply = reply[len(question):].split('\n')[0].strip()
+        if check:
+            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
+            reply = reply[len(question):].split('\n')[0].strip()
+        else:
+            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
+            reply = reply[len(question):].strip()
+            idx = reply.find(f"\n{name1}:")
+            if idx != -1:
+                reply = reply[:idx]
+
         history.append((text, reply))
         return history
 
@@ -228,14 +236,17 @@ elif args.chat:
                 name1 = gr.Textbox(value=name1_str, lines=1, label='Your name')
                 name2 = gr.Textbox(value=name2_str, lines=1, label='Bot\'s name')
                 context = gr.Textbox(value=context_str, lines=2, label='Context')
+                with gr.Row():
+                    check = gr.Checkbox(value=True, label='Stop generating at new line character?')
+
             with gr.Column():
                 display1 = gr.Chatbot()
                 textbox = gr.Textbox(lines=2, label='Input')
                 btn = gr.Button("Generate")
                 btn2 = gr.Button("Clear history")
 
-        btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context], display1, show_progress=True, api_name="textgen")
-        textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context], display1, show_progress=True)
+        btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
+        textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
         btn2.click(clear)
         btn.click(lambda x: "", textbox, textbox, show_progress=False)
         textbox.submit(lambda x: "", textbox, textbox, show_progress=False)