Просмотр исходного кода

Simplify the extensions implementation

oobabooga 2 лет назад
Родитель
Сommit
fe5057f932
1 измененных файлов с 37 добавлено и 41 удалено
  1. 37 41
      modules/extensions.py

+ 37 - 41
modules/extensions.py

@@ -2,62 +2,58 @@ import extensions
 import modules.shared as shared
 import gradio as gr
 
-extension_state = {}
+state = {}
 available_extensions = []
 
 def load_extensions():
-    global extension_state
-    for i,ext in enumerate(shared.args.extensions):
-        if ext in available_extensions:
-            print(f'Loading the extension "{ext}"... ', end='')
-            ext_string = f"extensions.{ext}.script"
-            exec(f"import {ext_string}")
-            extension_state[ext] = [True, i]
+    global state
+    for i, name in enumerate(shared.args.extensions):
+        if name in available_extensions:
+            print(f'Loading the extension "{name}"... ', end='')
+            import_string = f"extensions.{name}.script"
+            exec(f"import {import_string}")
+            state[name] = [True, i]
             print(f'Ok.')
 
+def iterator():
+    for name in sorted(state, key=lambda x : state[x][1]):
+        if state[name][0] == True:
+            yield eval(f"extensions.{name}.script"), name
+
 def apply_extensions(text, typ):
-    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
-        if extension_state[ext][0] == True:
-            ext_string = f"extensions.{ext}.script"
-            if typ == "input" and hasattr(eval(ext_string), "input_modifier"):
-                text = eval(f"{ext_string}.input_modifier(text)")
-            elif typ == "output" and hasattr(eval(ext_string), "output_modifier"):
-                text = eval(f"{ext_string}.output_modifier(text)")
-            elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"):
-                text = eval(f"{ext_string}.bot_prefix_modifier(text)")
+    for extension, _ in iterator():
+        if typ == "input" and hasattr(extension, "input_modifier"):
+            text = extension.input_modifier(text)
+        elif typ == "output" and hasattr(extension, "output_modifier"):
+            text = extension.output_modifier(text)
+        elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
+            text = extension.bot_prefix_modifier(text)
     return text
 
-def update_extensions_parameters(*kwargs):
+def update_extensions_parameters(*args):
     i = 0
-    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
-        if extension_state[ext][0] == True:
-            params = eval(f"extensions.{ext}.script.params")
-            for param in params:
-                if len(kwargs) >= i+1:
-                    params[param] = eval(f"kwargs[{i}]")
-                    i += 1
-
-def get_params(name):
-    return eval(f"extensions.{name}.script.params")
+    for extension, _ in iterator():
+        for param in extension.params:
+            if len(args) >= i+1:
+                extension.params[param] = eval(f"args[{i}]")
+                i += 1
 
 def create_extensions_block():
     extensions_ui_elements = []
     default_values = []
     if not (shared.args.chat or shared.args.cai_chat):
         gr.Markdown('## Extensions parameters')
-    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
-        if extension_state[ext][0] == True:
-            params = get_params(ext)
-            for param in params:
-                _id = f"{ext}-{param}"
-                default_value = shared.settings[_id] if _id in shared.settings else params[param]
-                default_values.append(default_value)
-                if type(params[param]) == str:
-                    extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
-                elif type(params[param]) in [int, float]:
-                    extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
-                elif type(params[param]) == bool:
-                    extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
+    for extension, name in iterator():
+        for param in extension.params:
+            _id = f"{name}-{param}"
+            default_value = shared.settings[_id] if _id in shared.settings else extension.params[param]
+            default_values.append(default_value)
+            if type(extension.params[param]) == str:
+                extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{name}-{param}"))
+            elif type(extension.params[param]) in [int, float]:
+                extensions_ui_elements.append(gr.Number(value=default_value, label=f"{name}-{param}"))
+            elif type(extension.params[param]) == bool:
+                extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{name}-{param}"))
 
     update_extensions_parameters(*default_values)
     btn_extensions = gr.Button("Apply")