Просмотр исходного кода

Add bot prefix modifier option in extensions

oobabooga 3 лет назад
Родитель
Сommit
e5ff4ddfc8
3 измененных файлов с 25 добавлено и 5 удалено
  1. 10 0
      extensions/example/script.py
  2. 9 0
      extensions/google_translate/script.py
  3. 6 5
      server.py

+ 10 - 0
extensions/example/script.py

@@ -1,5 +1,6 @@
 params = {
 params = {
     "input suffix": " *I say as I make a funny face*",
     "input suffix": " *I say as I make a funny face*",
+    "bot prefix": " *I speak in a cute way*",
 }
 }
 
 
 def input_modifier(string):
 def input_modifier(string):
@@ -16,3 +17,12 @@ def output_modifier(string):
     """
     """
 
 
     return string
     return string
+
+def bot_prefix_modifier(string):
+    """
+    This function is only applied in chat mode. It modifies
+    the prefix text for the Bot and can be used to bias its
+    behavior.
+    """
+
+    return string + params["bot prefix"]

+ 9 - 0
extensions/google_translate/script.py

@@ -20,3 +20,12 @@ def output_modifier(string):
     """
     """
 
 
     return translator.translate(string, src="en", dest=params['language string']).text
     return translator.translate(string, src="en", dest=params['language string']).text
+
+def bot_prefix_modifier(string):
+    """
+    This function is only applied in chat mode. It modifies
+    the prefix text for the Bot and can be used to bias its
+    behavior.
+    """
+
+    return string

+ 6 - 5
server.py

@@ -235,10 +235,12 @@ def apply_extensions(text, typ):
     for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
     for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
         if extension_state[ext][0] == True:
         if extension_state[ext][0] == True:
             ext_string = f"extensions.{ext}.script"
             ext_string = f"extensions.{ext}.script"
-            if typ == "input":
+            if typ == "input" and hasattr(eval(ext_string), "input_modifier"):
                 text = eval(f"{ext_string}.input_modifier(text)")
                 text = eval(f"{ext_string}.input_modifier(text)")
-            else:
+            elif typ == "output" and hasattr(eval(ext_string), "output_modifier"):
                 text = eval(f"{ext_string}.output_modifier(text)")
                 text = eval(f"{ext_string}.output_modifier(text)")
+            elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"):
+                text = eval(f"{ext_string}.bot_prefix_modifier(text)")
     return text
     return text
 
 
 def update_extensions_parameters(*kwargs):
 def update_extensions_parameters(*kwargs):
@@ -274,7 +276,6 @@ def create_extensions_block():
     btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
     btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
     return extensions_ui_elements, btn_extensions
     return extensions_ui_elements, btn_extensions
 
 
-
 def get_available_models():
 def get_available_models():
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
 
 
@@ -353,7 +354,7 @@ if args.chat or args.cai_chat:
             if history_size != 0 and count >= history_size:
             if history_size != 0 and count >= history_size:
                 break
                 break
         rows.append(f"{name1}: {text}\n")
         rows.append(f"{name1}: {text}\n")
-        rows.append(f"{name2}:")
+        rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
 
 
         while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
         while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
             rows.pop(1)
             rows.pop(1)
@@ -376,7 +377,7 @@ if args.chat or args.cai_chat:
             idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
             idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
             idx = idx[len(previous_idx)-1]
             idx = idx[len(previous_idx)-1]
 
 
-            reply = reply[idx + len(f"\n{name2}:"):]
+            reply = reply[idx + 1 + len(apply_extensions(f"{name2}:", "bot_prefix")):]
             if check:
             if check:
                 reply = reply.split('\n')[0].strip()
                 reply = reply.split('\n')[0].strip()
             else:
             else: