oobabooga 2 anni fa
parent
commit
ffb898608b
1 ha cambiato i file con 7 aggiunte e 15 eliminazioni
  1. 7 15
      server.py

+ 7 - 15
server.py

@@ -206,8 +206,8 @@ title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 
 
-if shared.args.chat or shared.args.cai_chat:
-    with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+    if shared.args.chat or shared.args.cai_chat:
         with gr.Tab("Text generation", elem_id="main"):
         with gr.Tab("Text generation", elem_id="main"):
             if shared.args.cai_chat:
             if shared.args.cai_chat:
                 shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
                 shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
@@ -276,9 +276,6 @@ if shared.args.chat or shared.args.cai_chat:
 
 
             create_settings_menus(default_preset)
             create_settings_menus(default_preset)
 
 
-        if shared.args.extensions is not None:
-            extensions_module.create_extensions_block()
-
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
 
 
@@ -325,8 +322,7 @@ if shared.args.chat or shared.args.cai_chat:
         shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
         shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
 
 
-elif shared.args.notebook:
-    with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+    elif shared.args.notebook:
         with gr.Tab("Text generation", elem_id="main"):
         with gr.Tab("Text generation", elem_id="main"):
             with gr.Tab('Raw'):
             with gr.Tab('Raw'):
                 shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
                 shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
@@ -344,9 +340,6 @@ elif shared.args.notebook:
         with gr.Tab("Settings", elem_id="settings"):
         with gr.Tab("Settings", elem_id="settings"):
             create_settings_menus(default_preset)
             create_settings_menus(default_preset)
 
 
-        if shared.args.extensions is not None:
-            extensions_module.create_extensions_block()
-
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
         output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
@@ -354,8 +347,7 @@ elif shared.args.notebook:
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
 
-else:
-    with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+    else:
         with gr.Tab("Text generation", elem_id="main"):
         with gr.Tab("Text generation", elem_id="main"):
             with gr.Row():
             with gr.Row():
                 with gr.Column():
                 with gr.Column():
@@ -380,9 +372,6 @@ else:
         with gr.Tab("Settings", elem_id="settings"):
         with gr.Tab("Settings", elem_id="settings"):
             create_settings_menus(default_preset)
             create_settings_menus(default_preset)
 
 
-        if shared.args.extensions is not None:
-            extensions_module.create_extensions_block()
-
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
         output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
@@ -391,6 +380,9 @@ else:
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
 
+    if shared.args.extensions is not None:
+        extensions_module.create_extensions_block()
+
 shared.gradio['interface'].queue()
 shared.gradio['interface'].queue()
 if shared.args.listen:
 if shared.args.listen:
     shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
     shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)