Преглед на файлове

Allow for permanent hijacking

oobabooga преди 2 години
родител
ревизия
67623a52b7
променени са 2 файла, в които са добавени 12 реда и са изтрити 8 реда
  1. 6 5
      extensions/send_pictures/script.py
  2. 6 3
      modules/chat.py

+ 6 - 5
extensions/send_pictures/script.py

@@ -9,15 +9,16 @@ from modules.bot_picture import caption_image
 params = {
 }
 
-# If 'state' is True, will hijack the next chatbot wrapper call
-# with a custom input text
+# If 'state' is 'temporary' or 'permanent', will hijack the next
+# chatbot wrapper call with a custom input text and optionally
+# custom output text
 input_hijack = {
-    'state': False,
-    'value': ["", ""]
+    'state': 'off',
+    'value': []
 }
 
 prompt_hijack = {
-    'state': False,
+    'state': 'off',
     'value': ""
 }
 

+ 6 - 3
modules/chat.py

@@ -91,8 +91,9 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     visible_text = None
     prompt = None
     for extension, _ in extensions_module.iterator():
-        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
-            extension.input_hijack['state'] = False
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] in ['temporary', 'permanent']:
+            if extension.input_hijack['state'] == 'temporary':
+                extension.input_hijack['state'] = 'off'
             values = extension.input_hijack['value']
             if len(values) == 2:
                 text, visible_text = values
@@ -102,7 +103,9 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
                     shared.history['internal'].append([text, reply])
                     shared.history['visible'].append([visible_text, visible_reply])
                 return shared.history['visible']
-        if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] == True:
+        if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] in ['temporary', 'permanent']:
+            if extension.prompt_hijack['state'] == 'temporary':
+                extension.prompt_hijack['state'] = 'off'
             prompt = extension.prompt_hijack['value']
                 
     just_started = True