Browse Source

Move function to extensions module

oobabooga 2 years ago
parent
commit
c87800341c
2 changed files with 28 additions and 28 deletions
  1. 24 0
      modules/extensions.py
  2. 4 28
      server.py

+ 24 - 0
modules/extensions.py

@@ -1,5 +1,6 @@
 import extensions
 import modules.shared as shared
+import gradio as gr
 
 extension_state = {}
 available_extensions = []
@@ -38,3 +39,26 @@ def update_extensions_parameters(*kwargs):
 
 def get_params(name):
     return eval(f"extensions.{name}.script.params")
+
+def create_extensions_block():
+    extensions_ui_elements = []
+    default_values = []
+    if not (shared.args.chat or shared.args.cai_chat):
+        gr.Markdown('## Extensions parameters')
+    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
+        if extension_state[ext][0] == True:
+            params = get_params(ext)
+            for param in params:
+                _id = f"{ext}-{param}"
+                default_value = shared.settings[_id] if _id in shared.settings else params[param]
+                default_values.append(default_value)
+                if type(params[param]) == str:
+                    extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
+                elif type(params[param]) in [int, float]:
+                    extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
+                elif type(params[param]) == bool:
+                    extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
+
+    update_extensions_parameters(*default_values)
+    btn_extensions = gr.Button("Apply")
+    btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])

+ 4 - 28
server.py

@@ -14,7 +14,6 @@ import modules.chat as chat
 import modules.extensions as extensions_module
 import modules.shared as shared
 import modules.ui as ui
-from modules.extensions import extension_state, load_extensions, update_extensions_parameters
 from modules.html_generator import generate_chat_html
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
@@ -95,29 +94,6 @@ def upload_soft_prompt(file):
 
     return name
 
-def create_extensions_block():
-    extensions_ui_elements = []
-    default_values = []
-    if not (shared.args.chat or shared.args.cai_chat):
-        gr.Markdown('## Extensions parameters')
-    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
-        if extension_state[ext][0] == True:
-            params = extensions_module.get_params(ext)
-            for param in params:
-                _id = f"{ext}-{param}"
-                default_value = shared.settings[_id] if _id in shared.settings else params[param]
-                default_values.append(default_value)
-                if type(params[param]) == str:
-                    extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
-                elif type(params[param]) in [int, float]:
-                    extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
-                elif type(params[param]) == bool:
-                    extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
-
-    update_extensions_parameters(*default_values)
-    btn_extensions = gr.Button("Apply")
-    btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
-
 def create_settings_menus():
     generate_params = load_preset_values(shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True)
 
@@ -176,7 +152,7 @@ available_softprompts = get_available_softprompts()
 
 extensions_module.available_extensions = get_available_extensions()
 if shared.args.extensions is not None:
-    load_extensions()
+    extensions_module.load_extensions()
 
 # Choosing the default model
 if shared.args.model is not None:
@@ -279,7 +255,7 @@ if shared.args.chat or shared.args.cai_chat:
 
         if shared.args.extensions is not None:
             with gr.Tab("Extensions"):
-                create_extensions_block()
+                extensions_module.create_extensions_block()
 
         input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
         if shared.args.picture:
@@ -340,7 +316,7 @@ elif shared.args.notebook:
         preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
 
         if shared.args.extensions is not None:
-            create_extensions_block()
+            extensions_module.create_extensions_block()
 
         gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
         gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream))
@@ -362,7 +338,7 @@ else:
 
                 preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
                 if shared.args.extensions is not None:
-                    create_extensions_block()
+                    extensions_module.create_extensions_block()
 
             with gr.Column():
                 with gr.Tab('Raw'):