Przeglądaj źródła

Add a "stop" button

oobabooga 3 lat temu
rodzic
commit
3cb30bed0a
1 zmienionych plików z 43 dodań i 40 usunięć
  1. 43 40
      server.py

+ 43 - 40
server.py

@@ -191,27 +191,7 @@ else:
 description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
 description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
 css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
 css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
 
 
-if args.notebook:
-    with gr.Blocks(css=css, analytics_enabled=False) as interface:
-        gr.Markdown(description)
-        with gr.Tab('Raw'):
-            textbox = gr.Textbox(value=default_text, lines=23)
-        with gr.Tab('Markdown'):
-            markdown = gr.Markdown()
-        with gr.Tab('HTML'):
-            html = gr.HTML()
-        btn = gr.Button("Generate")
-
-        length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
-        with gr.Row():
-            with gr.Column():
-                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
-            with gr.Column():
-                preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
-
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
-elif args.chat or args.cai_chat:
+if args.chat or args.cai_chat:
     history = []
     history = []
 
 
     # This gets the new line characters right.
     # This gets the new line characters right.
@@ -311,10 +291,9 @@ elif args.chat or args.cai_chat:
         textbox = gr.Textbox(lines=2, label='Input')
         textbox = gr.Textbox(lines=2, label='Input')
         btn = gr.Button("Generate")
         btn = gr.Button("Generate")
         with gr.Row():
         with gr.Row():
-            with gr.Column():
-                btn3 = gr.Button("Remove last message")
-            with gr.Column():
-                btn2 = gr.Button("Clear history")
+            btn2 = gr.Button("Clear history")
+            stop = gr.Button("Stop")
+            btn3 = gr.Button("Remove last message")
 
 
         length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
         length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
         with gr.Row():
         with gr.Row():
@@ -330,25 +309,44 @@ elif args.chat or args.cai_chat:
             check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
             check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
 
 
         if args.cai_chat:
         if args.cai_chat:
-            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
-            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
-            btn2.click(clear_html, [], display1, show_progress=True)
+            gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(clear_html, [], display1, show_progress=False)
         else:
         else:
-            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
-            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
-            btn2.click(lambda x: "", display1, display1, show_progress=True)
+            gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(lambda x: "", display1, display1, show_progress=False)
 
 
         btn2.click(clear)
         btn2.click(clear)
         btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
         btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
         btn.click(lambda x: "", textbox, textbox, show_progress=False)
         btn.click(lambda x: "", textbox, textbox, show_progress=False)
         textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
         textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
-else:
+        stop.click(None, None, None, cancels=[gen_event, gen_event2])
 
 
-    def continue_wrapper(question, tokens, inference_settings, selected_model):
-        for i in generate_reply(question, tokens, inference_settings, selected_model):
-            a, b, c = i
-            yield a, a, b, c
+elif args.notebook:
+    with gr.Blocks(css=css, analytics_enabled=False) as interface:
+        gr.Markdown(description)
+        with gr.Tab('Raw'):
+            textbox = gr.Textbox(value=default_text, lines=23)
+        with gr.Tab('Markdown'):
+            markdown = gr.Markdown()
+        with gr.Tab('HTML'):
+            html = gr.HTML()
+        btn = gr.Button("Generate")
+        stop = gr.Button("Stop")
 
 
+        length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
+        with gr.Row():
+            with gr.Column():
+                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+            with gr.Column():
+                preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
+
+        gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
+        gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
+        stop.click(None, None, None, cancels=[gen_event, gen_event2])
+
+else:
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
         gr.Markdown(description)
         gr.Markdown(description)
         with gr.Row():
         with gr.Row():
@@ -358,7 +356,11 @@ else:
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
                 model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                 model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                 btn = gr.Button("Generate")
                 btn = gr.Button("Generate")
-                cont = gr.Button("Continue")
+                with gr.Row():
+                    with gr.Column():
+                        cont = gr.Button("Continue")
+                    with gr.Column():
+                        stop = gr.Button("Stop")
             with gr.Column():
             with gr.Column():
                 with gr.Tab('Raw'):
                 with gr.Tab('Raw'):
                     output_textbox = gr.Textbox(lines=15, label='Output')
                     output_textbox = gr.Textbox(lines=15, label='Output')
@@ -367,9 +369,10 @@ else:
                 with gr.Tab('HTML'):
                 with gr.Tab('HTML'):
                     html = gr.HTML()
                     html = gr.HTML()
 
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
-        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False)
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
+        gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
+        gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
+        cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
+        stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event])
 
 
 interface.queue()
 interface.queue()
 if args.no_listen:
 if args.no_listen: