Просмотр исходного кода

Enable extensions in all modes, not just chat

oobabooga 3 лет назад
Родитель
Сommit
b6d01bb704
1 измененных файлов с 41 добавлено и 30 удалено
  1. 41 30
      server.py

+ 41 - 30
server.py

@@ -241,6 +241,40 @@ def apply_extensions(text, typ):
                 text = eval(f"{ext_string}.output_modifier(text)")
                 text = eval(f"{ext_string}.output_modifier(text)")
     return text
     return text
 
 
+def update_extensions_parameters(*kwargs):
+    i = 0
+    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
+        if extension_state[ext][0] == True:
+            params = eval(f"extensions.{ext}.script.params")
+            for param in params:
+                if len(kwargs) >= i+1:
+                    params[param] = eval(f"kwargs[{i}]")
+                    i += 1
+
+def create_extensions_block():
+    extensions_ui_elements = []
+    default_values = []
+    gr.Markdown('## Extensions parameters')
+    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
+        if extension_state[ext][0] == True:
+            params = eval(f"extensions.{ext}.script.params")
+            for param in params:
+                _id = f"{ext}-{param}"
+                default_value = settings[_id] if _id in settings else params[param]
+                default_values.append(default_value)
+                if type(params[param]) == str:
+                    extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
+                elif type(params[param]) in [int, float]:
+                    extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
+                elif type(params[param]) == bool:
+                    extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
+
+    update_extensions_parameters(*default_values)
+    btn_extensions = gr.Button("Apply")
+    btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
+    return extensions_ui_elements, btn_extensions
+
+
 def get_available_models():
 def get_available_models():
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
 
 
@@ -606,36 +640,7 @@ if args.chat or args.cai_chat:
                 upload_img_tavern = gr.File(type='binary')
                 upload_img_tavern = gr.File(type='binary')
 
 
         if args.extensions is not None:
         if args.extensions is not None:
-            extensions_ui_elements = []
-            default_values = []
-            gr.Markdown('## Extensions parameters')
-            for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
-                if extension_state[ext][0] == True:
-                    params = eval(f"extensions.{ext}.script.params")
-                    for param in params:
-                        _id = f"{ext}-{param}"
-                        default_value = settings[_id] if _id in settings else params[param]
-                        default_values.append(default_value)
-                        if type(params[param]) == str:
-                            extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
-                        elif type(params[param]) in [int, float]:
-                            extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
-                        elif type(params[param]) == bool:
-                            extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
-
-            def update_extensions_parameters(*kwargs):
-                i = 0
-                for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
-                    if extension_state[ext][0] == True:
-                        params = eval(f"extensions.{ext}.script.params")
-                        for param in params:
-                            if len(kwargs) >= i+1:
-                                params[param] = eval(f"kwargs[{i}]")
-                                i += 1
-
-            update_extensions_parameters(*default_values)
-            btn_extensions = gr.Button("Apply")
-            btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
+            extensions_ui_elements, btn_extensions = create_extensions_block()
 
 
         input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
         input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
         if args.cai_chat:
         if args.cai_chat:
@@ -689,6 +694,9 @@ elif args.notebook:
                     preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Generation parameters preset')
                     preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Generation parameters preset')
                     create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
                     create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
 
 
+        if args.extensions is not None:
+            extensions_ui_elements, btn_extensions = create_extensions_block()
+
         gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
         gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
         gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
         gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
         stop.click(None, None, None, cancels=[gen_event, gen_event2])
         stop.click(None, None, None, cancels=[gen_event, gen_event2])
@@ -712,6 +720,9 @@ else:
                         cont = gr.Button("Continue")
                         cont = gr.Button("Continue")
                     with gr.Column():
                     with gr.Column():
                         stop = gr.Button("Stop")
                         stop = gr.Button("Stop")
+                if args.extensions is not None:
+                    extensions_ui_elements, btn_extensions = create_extensions_block()
+
             with gr.Column():
             with gr.Column():
                 with gr.Tab('Raw'):
                 with gr.Tab('Raw'):
                     output_textbox = gr.Textbox(lines=15, label='Output')
                     output_textbox = gr.Textbox(lines=15, label='Output')