
Add a header bar and redesign the interface (#293)

oobabooga, 2 years ago
commit 1413931705
3 files changed, 102 additions and 74 deletions
  1. extensions/gallery/script.py (+1, -1)
  2. modules/ui.py (+9, -0)
  3. server.py (+92, -73)

+ 1 - 1
extensions/gallery/script.py

@@ -76,7 +76,7 @@ def generate_html():
     return container_html
 
 def ui():
-    with gr.Accordion("Character gallery"):
+    with gr.Accordion("Character gallery", open=False):
         update = gr.Button("Refresh")
         gallery = gr.HTML(value=generate_html())
     update.click(generate_html, [], gallery)
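
For reference, the only functional change in this file is the new open=False argument, which makes the character gallery start collapsed instead of expanded. A minimal standalone sketch of the same pattern, assuming a Gradio 3.x API and a stubbed-out generate_html:

    import gradio as gr

    def generate_html():
        # Stub for the gallery renderer in extensions/gallery/script.py.
        return "<div>character cards go here</div>"

    with gr.Blocks() as demo:
        # open=False collapses the accordion by default; clicking its header expands it.
        with gr.Accordion("Character gallery", open=False):
            update = gr.Button("Refresh")
            gallery = gr.HTML(value=generate_html())
        # The Refresh button re-renders the HTML inside the accordion.
        update.click(generate_html, [], gallery)

    demo.launch()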

+ 9 - 0
modules/ui.py

@@ -38,6 +38,9 @@ svg {
 ol li p, ul li p {
     display: inline-block;
 }
+#main, #settings, #extensions, #chat-settings {
+  border: 0;
+}
 """
 
 chat_css = """
@@ -64,6 +67,12 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
 }
 """
 
+page_js = """
+document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px"
+document.getElementById("main").parentNode.style = "padding: 0; margin: 0"
+document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"
+"""
+
 class ToolButton(gr.Button, gr.components.FormComponent):
     """Small button with single emoji as text, fits inside gradio forms"""
 

+ 92 - 73
server.py

@@ -101,9 +101,7 @@ def upload_soft_prompt(file):
 
     return name
 
-def create_settings_menus(default_preset):
-    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
-
+def create_model_and_preset_menus():
     with gr.Row():
         with gr.Column():
             with gr.Row():
@@ -114,7 +112,11 @@ def create_settings_menus(default_preset):
                 shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                 ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
 
-    with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
+def create_settings_menus(default_preset):
+    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+
+    with gr.Box():
+        gr.Markdown('Custom generation parameters')
         with gr.Row():
             with gr.Column():
                 shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
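
Note: the ui.create_refresh_button calls above pair each dropdown with a small button that re-queries its choices. A hedged sketch of what such a helper might look like, inferred from the call sites only and not necessarily the project's actual implementation in modules/ui.py:

    import gradio as gr

    def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
        # Illustrative only: run refresh_method, then push freshly computed
        # constructor args (e.g. new dropdown choices) back into the component.
        def refresh():
            refresh_method()
            return gr.update(**refreshed_args())

        button = gr.Button("🔄", elem_id=elem_id)
        button.click(fn=refresh, inputs=[], outputs=[refresh_component])
        return button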
@@ -128,9 +130,11 @@ def create_settings_menus(default_preset):
                 shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
         shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
 
+    with gr.Box():
         gr.Markdown('Contrastive search:')
         shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
 
+    with gr.Box():
         gr.Markdown('Beam search (uses a lot of VRAM):')
         with gr.Row():
             with gr.Column():
@@ -139,7 +143,8 @@ def create_settings_menus(default_preset):
                 shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
         shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
-    with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
+    with gr.Box():
+        gr.Markdown('Soft prompt')
         with gr.Row():
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
             ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
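
Note: the recurring change in create_settings_menus is swapping collapsible gr.Accordion sections for always-visible gr.Box groups, each headed by a plain gr.Markdown label. Reduced to a single control, the before/after looks roughly like this (the slider values are placeholders):

    import gradio as gr

    with gr.Blocks() as demo:
        # Before: a collapsible section the user had to expand manually.
        # with gr.Accordion('Custom generation parameters', open=False):
        #     gr.Slider(0.01, 1.99, value=0.7, step=0.01, label='temperature')

        # After: a framed, always-visible group with a text heading.
        with gr.Box():
            gr.Markdown('Custom generation parameters')
            gr.Slider(0.01, 1.99, value=0.7, step=0.01, label='temperature')

    demo.launch()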
@@ -202,26 +207,41 @@ suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 
 if shared.args.chat or shared.args.cai_chat:
     with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        if shared.args.cai_chat:
-            shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
-        else:
-            shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
-        shared.gradio['textbox'] = gr.Textbox(label='Input')
-        with gr.Row():
-            shared.gradio['Stop'] = gr.Button('Stop')
-            shared.gradio['Generate'] = gr.Button('Generate')
-        with gr.Row():
-            shared.gradio['Impersonate'] = gr.Button('Impersonate')
-            shared.gradio['Regenerate'] = gr.Button('Regenerate')
-        with gr.Row():
-            shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
-            shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
-            shared.gradio['Remove last'] = gr.Button('Remove last')
-
-            shared.gradio['Clear history'] = gr.Button('Clear history')
-            shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
-            shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
-        with gr.Tab('Chat settings'):
+        with gr.Tab("Text generation", elem_id="main"):
+            if shared.args.cai_chat:
+                shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+            else:
+                shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
+            shared.gradio['textbox'] = gr.Textbox(label='Input')
+            with gr.Row():
+                shared.gradio['Stop'] = gr.Button('Stop')
+                shared.gradio['Generate'] = gr.Button('Generate')
+            with gr.Row():
+                shared.gradio['Impersonate'] = gr.Button('Impersonate')
+                shared.gradio['Regenerate'] = gr.Button('Regenerate')
+            with gr.Row():
+                shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+                shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+                shared.gradio['Remove last'] = gr.Button('Remove last')
+
+                shared.gradio['Clear history'] = gr.Button('Clear history')
+                shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
+                shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+
+            create_model_and_preset_menus()
+
+            with gr.Box():
+                with gr.Row():
+                    with gr.Column():
+                        shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                        shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+                    with gr.Column():
+                        shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
+
+            if shared.args.extensions is not None:
+                extensions_module.create_extensions_block()
+
+        with gr.Tab("Chat settings", elem_id="chat-settings"):
             shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
             shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
             shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
@@ -255,21 +275,11 @@ if shared.args.chat or shared.args.cai_chat:
                 with gr.Tab('Upload TavernAI Character Card'):
                     shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
 
-        with gr.Tab('Generation settings'):
-            with gr.Row():
-                with gr.Column():
-                    shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-                with gr.Column():
-                    shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
-                shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
+        with gr.Tab("Settings", elem_id="settings"):
             create_settings_menus(default_preset)
 
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
-        if shared.args.extensions is not None:
-            with gr.Tab('Extensions'):
-                extensions_module.create_extensions_block()
-
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
 
         gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
         gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
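
Note: the click and submit dependencies collected into gen_events here are what the Stop button cancels further down (shared.gradio['Stop'].click(None, None, None, cancels=gen_events)); cancellation only takes effect because the interface queue is enabled at the end of the file. Boiled down to one event, with a toy streaming function standing in for the chat wrapper:

    import time
    import gradio as gr

    def slow_reply(prompt):
        # Toy stand-in for the streaming generation wrapper.
        text = ""
        for word in ("streaming", "a", "slow", "reply"):
            text += word + " "
            time.sleep(1)
            yield text

    gen_events = []
    with gr.Blocks() as demo:
        textbox = gr.Textbox(label='Input')
        display = gr.Textbox(label='Output')
        generate = gr.Button('Generate')
        stop = gr.Button('Stop')
        gen_events.append(generate.click(slow_reply, textbox, display))
        # cancels= aborts any in-flight run of the listed event dependencies.
        stop.click(None, None, None, cancels=gen_events)

    demo.queue().launch()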
@@ -310,58 +320,66 @@ if shared.args.chat or shared.args.cai_chat:
         shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
         shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
 
+        shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
         shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
 
 elif shared.args.notebook:
     with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        gr.Markdown(description)
-        with gr.Tab('Raw'):
-            shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
-        with gr.Tab('Markdown'):
-            shared.gradio['markdown'] = gr.Markdown()
-        with gr.Tab('HTML'):
-            shared.gradio['html'] = gr.HTML()
-
-        shared.gradio['Generate'] = gr.Button('Generate')
-        shared.gradio['Stop'] = gr.Button('Stop')
-        shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-
-        create_settings_menus(default_preset)
-        if shared.args.extensions is not None:
-            extensions_module.create_extensions_block()
+        with gr.Tab("Text generation", elem_id="main"):
+            with gr.Tab('Raw'):
+                shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
+            with gr.Tab('Markdown'):
+                shared.gradio['markdown'] = gr.Markdown()
+            with gr.Tab('HTML'):
+                shared.gradio['html'] = gr.HTML()
+
+            with gr.Row():
+                shared.gradio['Stop'] = gr.Button('Stop')
+                shared.gradio['Generate'] = gr.Button('Generate')
+            shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+
+            create_model_and_preset_menus()
+            if shared.args.extensions is not None:
+                extensions_module.create_extensions_block()
+
+        with gr.Tab("Settings", elem_id="settings"):
+            create_settings_menus(default_preset)
 
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
         gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+        shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
 
 else:
     with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        gr.Markdown(description)
-        with gr.Row():
-            with gr.Column():
-                shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
-                shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-                shared.gradio['Generate'] = gr.Button('Generate')
-                with gr.Row():
-                    with gr.Column():
-                        shared.gradio['Continue'] = gr.Button('Continue')
-                    with gr.Column():
-                        shared.gradio['Stop'] = gr.Button('Stop')
+        with gr.Tab("Text generation", elem_id="main"):
+            with gr.Row():
+                with gr.Column():
+                    shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
+                    shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                    shared.gradio['Generate'] = gr.Button('Generate')
+                    with gr.Row():
+                        with gr.Column():
+                            shared.gradio['Continue'] = gr.Button('Continue')
+                        with gr.Column():
+                            shared.gradio['Stop'] = gr.Button('Stop')
 
-                create_settings_menus(default_preset)
-                if shared.args.extensions is not None:
-                    extensions_module.create_extensions_block()
+                    create_model_and_preset_menus()
+                    if shared.args.extensions is not None:
+                        extensions_module.create_extensions_block()
 
-            with gr.Column():
-                with gr.Tab('Raw'):
-                    shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
-                with gr.Tab('Markdown'):
-                    shared.gradio['markdown'] = gr.Markdown()
-                with gr.Tab('HTML'):
-                    shared.gradio['html'] = gr.HTML()
+                with gr.Column():
+                    with gr.Tab('Raw'):
+                        shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
+                    with gr.Tab('Markdown'):
+                        shared.gradio['markdown'] = gr.Markdown()
+                    with gr.Tab('HTML'):
+                        shared.gradio['html'] = gr.HTML()
+        with gr.Tab("Settings", elem_id="settings"):
+            create_settings_menus(default_preset)
 
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
@@ -369,6 +387,7 @@ else:
         gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
         gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+        shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
 
 shared.gradio['interface'].queue()
 if shared.args.listen:
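
Note: across all three interface modes the redesign follows the same shape: each top-level section becomes a gr.Tab with an elem_id that the new CSS rule (#main, #settings, #extensions, #chat-settings { border: 0; }) and the page_js snippet then restyle into a header bar. A stripped-down skeleton of the chat-mode layout assembled in this diff, with placeholder Markdown standing in for the real contents:

    import gradio as gr
    from modules import ui

    with gr.Blocks(css=ui.css + ui.chat_css, analytics_enabled=False) as interface:
        with gr.Tab("Text generation", elem_id="main"):
            gr.Markdown("chat display, buttons, model/preset menus, extensions block")
        with gr.Tab("Chat settings", elem_id="chat-settings"):
            gr.Markdown("names, context, character and history upload")
        with gr.Tab("Settings", elem_id="settings"):
            gr.Markdown("generation parameter menus")
        # Client-side restyling of the tab header into a header bar.
        interface.load(None, None, None, _js=f"() => {{{ui.page_js}}}")

    interface.queue()
    interface.launch()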