Procházet zdrojové kódy

"character greeting" displayed and editable on the fly (#743)

* Add greetings field

* add greeting field and make it interactive

* Minor changes

* Fix a bug

* Simplify clear_chat_log

* Change a label

* Minor change

* Simplifications

* Simplification

* Simplify loading the default character history

* Fix regression

---------

Co-authored-by: oobabooga
OWKenobi před 2 roky
rodič
revize
dcf61a8897
3 změnil soubory, kde provedl 35 přidání a 41 odebrání
  1. 29 37
      modules/chat.py
  2. 1 0
      modules/shared.py
  3. 5 4
      server.py

+ 29 - 37
modules/chat.py

@@ -35,7 +35,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
         prev_user_input = shared.history['internal'][i][0]
-        if len(prev_user_input) > 0 and prev_user_input != '<|BEGIN-VISIBLE-CHAT|>':
+        if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
             rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
         i -= 1
 
@@ -198,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
             yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
 def remove_last_message(name1, name2):
-    if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
+    if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
         shared.history['internal'].pop()
     else:
@@ -228,21 +228,13 @@ def replace_last_reply(text, name1, name2):
 def clear_html():
     return generate_chat_html([], "", "", shared.character)
 
-def clear_chat_log(name1, name2):
-    if shared.character != 'None':
-        found = False
-        for i in range(len(shared.history['internal'])):
-            if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
-                shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
-                shared.history['internal'] = [shared.history['internal'][i]]
-                found = True
-                break
-        if not found:
-            shared.history['visible'] = []
-            shared.history['internal'] = []
-    else:
-        shared.history['internal'] = []
-        shared.history['visible'] = []
+def clear_chat_log(name1, name2, greeting):
+    shared.history['visible'] = []
+    shared.history['internal'] = []
+
+    if greeting != '':
+        shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
+        shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
 
     return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
@@ -287,11 +279,10 @@ def tokenize_dialogue(dialogue, name1, name2):
     return history
 
 def save_history(timestamp=True):
-    prefix = '' if shared.character == 'None' else f"{shared.character}_"
     if timestamp:
-        fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
+        fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
     else:
-        fname = f"{prefix}persistent.json"
+        fname = f"{shared.character}_persistent.json"
     if not Path('logs').exists():
         Path('logs').mkdir()
     with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
@@ -322,14 +313,6 @@ def load_history(file, name1, name2):
         shared.history['internal'] = tokenize_dialogue(file, name1, name2)
         shared.history['visible'] = copy.deepcopy(shared.history['internal'])
 
-def load_default_history(name1, name2):
-    shared.character = 'None'
-    if Path('logs/persistent.json').exists():
-        load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2)
-    else:
-        shared.history['internal'] = []
-        shared.history['visible'] = []
-
 def replace_character_names(text, name1, name2):
     text = text.replace('{{user}}', name1).replace('{{char}}', name2)
     return text.replace('<USER>', name1).replace('<BOT>', name2)
@@ -343,20 +326,24 @@ def build_pygmalion_style_context(data):
     context = f"{context.strip()}\n<START>\n"
     return context
 
-def load_character(_character, name1, name2):
+def load_character(character, name1, name2):
+    shared.character = character
     shared.history['internal'] = []
     shared.history['visible'] = []
-    if _character != 'None':
-        shared.character = _character
+    greeting = ""
 
+    if character != 'None':
         for extension in ["yml", "yaml", "json"]:
-            filepath = Path(f'characters/{_character}.{extension}')
+            filepath = Path(f'characters/{character}.{extension}')
             if filepath.exists():
                 break
         file_contents = open(filepath, 'r', encoding='utf-8').read()
         data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
 
+        if 'your_name' in data and data['your_name'] != '':
+            name1 = data['your_name']
         name2 = data['name'] if 'name' in data else data['char_name']
+
         for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']:
             if field in data:
                 data[field] = replace_character_names(data[field], name1, name2)
@@ -371,20 +358,25 @@ def load_character(_character, name1, name2):
         if 'example_dialogue' in data and data['example_dialogue'] != '':
             context += f"{data['example_dialogue'].strip()}\n"
         if greeting_field in data and len(data[greeting_field].strip()) > 0:
-            shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data[greeting_field]]]
-            shared.history['visible'] += [['', apply_extensions(data[greeting_field], "output")]]
+            greeting = data[greeting_field]  
     else:
-        shared.character = 'None'
         context = shared.settings['context']
         name2 = shared.settings['name2']
+        greeting = shared.settings['greeting'] 
 
     if Path(f'logs/{shared.character}_persistent.json').exists():
         load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
+    elif greeting != "":
+        shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
+        shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
 
     if shared.args.cai_chat:
-        return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
+        return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
     else:
-        return name2, context, shared.history['visible']
+        return name1, name2, greeting, context, shared.history['visible']
+
+def load_default_history(name1, name2):
+    load_character("None", name1, name2)
 
 def upload_character(json_file, img, tavern=False):
     json_file = json_file if type(json_file) == str else json_file.decode('utf-8')

+ 1 - 0
modules/shared.py

@@ -31,6 +31,7 @@ settings = {
     'name1': 'You',
     'name2': 'Assistant',
     'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
+    'greeting': 'Hello there!',
     'stop_at_newline': False,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,

+ 5 - 4
server.py

@@ -317,8 +317,9 @@ def create_interface():
 
             with gr.Tab("Character", elem_id="chat-settings"):
                 shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
-                shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
-                shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
+                shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character''s name')
+                shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
+                shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
                 with gr.Row():
                     shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
                     ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -381,7 +382,7 @@ def create_interface():
             clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
             shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
             shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
-            shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+            shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
             shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
 
             shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
@@ -396,7 +397,7 @@ def create_interface():
             shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
             shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
 
-            shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
+            shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
             shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
             shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
             shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])