Pārlūkot izejas kodu

Merge pull request #433 from mayaeary/fix/api-reload

Fix api extension duplicating
oobabooga 2 gadi atpakaļ
vecāks
revīzija
c14e598f14
2 mainītis faili ar 19 papildinājumiem un 4 dzēšanām
  1. 4 3
      extensions/api/script.py
  2. 15 1
      modules/extensions.py

+ 4 - 3
extensions/api/script.py

@@ -1,8 +1,9 @@
+import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 from threading import Thread
+
 from modules import shared
 from modules import shared
-from modules.text_generation import generate_reply, encode
-import json
+from modules.text_generation import encode, generate_reply
 
 
 params = {
 params = {
     'port': 5000,
     'port': 5000,
@@ -87,5 +88,5 @@ def run_server():
         print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
         print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
     server.serve_forever()
     server.serve_forever()
 
 
-def ui():
+def setup():
     Thread(target=run_server, daemon=True).start()
     Thread(target=run_server, daemon=True).start()

+ 15 - 1
modules/extensions.py

@@ -7,6 +7,7 @@ import modules.shared as shared
 
 
 state = {}
 state = {}
 available_extensions = []
 available_extensions = []
+setup_called = False
 
 
 def load_extensions():
 def load_extensions():
     global state
     global state
@@ -39,6 +40,8 @@ def apply_extensions(text, typ):
     return text
     return text
 
 
 def create_extensions_block():
 def create_extensions_block():
+    global setup_called
+
     # Updating the default values
     # Updating the default values
     for extension, name in iterator():
     for extension, name in iterator():
         if hasattr(extension, 'params'):
         if hasattr(extension, 'params'):
@@ -47,8 +50,19 @@ def create_extensions_block():
                 if _id in shared.settings:
                 if _id in shared.settings:
                     extension.params[param] = shared.settings[_id]
                     extension.params[param] = shared.settings[_id]
 
 
+    should_display_ui = False
+
+    # Running setup function
+    if not setup_called:
+        for extension, name in iterator():
+            if hasattr(extension, "setup"):
+                extension.setup()
+            if hasattr(extension, "ui"):
+                should_display_ui = True
+        setup_called = True
+
     # Creating the extension ui elements
     # Creating the extension ui elements
-    if len(state) > 0:
+    if should_display_ui:
         with gr.Box(elem_id="extensions"):
         with gr.Box(elem_id="extensions"):
             gr.Markdown("Extensions")
             gr.Markdown("Extensions")
             for extension, name in iterator():
             for extension, name in iterator():