Przeglądaj źródła

Apply dialogue format in all character fields not just example dialogue (#650)

Alex "mcmonkey" Goodwin 2 lat temu
rodzic
commit
ea97303509
1 zmienionych plików z 10 dodań i 5 usunięć
  1. 10 5
      modules/chat.py

+ 10 - 5
modules/chat.py

@@ -330,6 +330,10 @@ def load_default_history(name1, name2):
         shared.history['internal'] = []
         shared.history['internal'] = []
         shared.history['visible'] = []
         shared.history['visible'] = []
 
 
+def replace_character_names(text, name1, name2):
+    text = text.replace('{{user}}', name1).replace('{{char}}', name2)
+    return text.replace('<USER>', name1).replace('<BOT>', name2)
+
 def build_pygmalion_style_context(data):
 def build_pygmalion_style_context(data):
     context = ""
     context = ""
     if 'char_persona' in data and data['char_persona'] != '':
     if 'char_persona' in data and data['char_persona'] != '':
@@ -344,25 +348,26 @@ def load_character(_character, name1, name2):
     shared.history['visible'] = []
     shared.history['visible'] = []
     if _character != 'None':
     if _character != 'None':
         shared.character = _character
         shared.character = _character
-        
+
         for extension in  ["yml", "yaml", "json"]:
         for extension in  ["yml", "yaml", "json"]:
             filepath = Path(f'characters/{_character}.{extension}')
             filepath = Path(f'characters/{_character}.{extension}')
             if filepath.exists():
             if filepath.exists():
                 break
                 break
         data = yaml.safe_load(open(filepath, 'r', encoding='utf-8').read())
         data = yaml.safe_load(open(filepath, 'r', encoding='utf-8').read())
 
 
+        name2 = data['name'] if 'name' in data else data['char_name']
+        for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']:
+            if field in data:
+                data[field] = replace_character_names(data[field], name1, name2)
+
         if 'context' in data:
         if 'context' in data:
             context = f"{data['context'].strip()}\n\n"
             context = f"{data['context'].strip()}\n\n"
-            name2 = data['name']
             greeting_field = 'greeting'
             greeting_field = 'greeting'
         else:
         else:
             context = build_pygmalion_style_context(data)
             context = build_pygmalion_style_context(data)
-            name2 = data['char_name']
             greeting_field = 'char_greeting'
             greeting_field = 'char_greeting'
 
 
         if 'example_dialogue' in data and data['example_dialogue'] != '':
         if 'example_dialogue' in data and data['example_dialogue'] != '':
-            data['example_dialogue'] = data['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', name2)
-            data['example_dialogue'] = data['example_dialogue'].replace('<USER>', name1).replace('<BOT>', name2)
             context += f"{data['example_dialogue'].strip()}\n"
             context += f"{data['example_dialogue'].strip()}\n"
         if greeting_field in data and len(data[greeting_field].strip()) > 0:
         if greeting_field in data and len(data[greeting_field].strip()) > 0:
             shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data[greeting_field]]]
             shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data[greeting_field]]]