Move chat history into shared module

oobabooga 2 years ago
parent
commit
2e86a1ec04
3 changed files with 83 additions and 86 deletions
  1. modules/chat.py  +74 -81
  2. modules/shared.py  +4 -0
  3. server.py  +5 -5
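
The diff below follows one pattern throughout: the mutable chat state (the history dict and the active character name) moves out of module-level globals in modules/chat.py into modules/shared.py, and every access now goes through the shared namespace (shared.history, shared.character). Because callers mutate a single shared object instead of rebinding their own module globals, the "global history, character" declarations in chat.py can simply be dropped. A minimal, self-contained Python sketch of the pattern; the _Shared class is only a stand-in for "import modules.shared as shared":

# Stand-in for modules/shared.py: one module owns the mutable chat state.
class _Shared:
    history = {'internal': [], 'visible': []}
    character = 'None'

shared = _Shared()  # in the real code this is `import modules.shared as shared`

# Stand-in for a chat.py helper: it reads and mutates shared.history directly,
# so no `global history` declaration is needed any more.
def remove_last_message():
    if shared.history['internal'] and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
        shared.history['internal'].pop()
        return shared.history['visible'].pop()
    return ['', '']

shared.history['internal'].append(['hi there', 'hello'])
shared.history['visible'].append(['hi there', 'hello'])
print(remove_last_message())  # ['hi there', 'hello']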

+ 74 - 81
modules/chat.py

@@ -17,9 +17,6 @@ from modules.text_generation import encode, generate_reply, get_max_prompt_lengt
 if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
     import modules.bot_picture as bot_picture
 
-history = {'internal': [], 'visible': []}
-character = None
-
 # This gets the new line characters right.
 def clean_chat_message(text):
     text = text.replace('\n', '\n\n')
@@ -30,7 +27,7 @@ def clean_chat_message(text):
 def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
     text = clean_chat_message(text)
     rows = [f"{context.strip()}\n"]
-    i = len(history['internal'])-1
+    i = len(shared.history['internal'])-1
     count = 0
 
     if shared.soft_prompt:
@@ -38,10 +35,10 @@ def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size,
     max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
 
     while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
-        rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
+        rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
         count += 1
-        if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
-            rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
+        if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
+            rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
             count += 1
         i -= 1
 
@@ -130,20 +127,20 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
         # We need this global variable to handle the Stop event,
         # otherwise gradio gets confused
         if stop_everything:
-            return history['visible']
+            return shared.history['visible']
 
         if first:
             first = False
-            history['internal'].append(['', ''])
-            history['visible'].append(['', ''])
+            shared.history['internal'].append(['', ''])
+            shared.history['visible'].append(['', ''])
 
-        history['internal'][-1] = [text, reply]
-        history['visible'][-1] = [visible_text, visible_reply]
+        shared.history['internal'][-1] = [text, reply]
+        shared.history['visible'][-1] = [visible_text, visible_reply]
         if not substring_found:
-            yield history['visible']
+            yield shared.history['visible']
         if next_character_found:
             break
-    yield history['visible']
+    yield shared.history['visible']
 
 def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
     if 'pygmalion' in shared.model_name.lower():
@@ -161,78 +158,76 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
 
 def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
     for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
-        yield generate_chat_html(_history, name1, name2, character)
+        yield generate_chat_html(_history, name1, name2, shared.character)
 
 def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
-    if character is not None and len(history['visible']) == 1:
+    if shared.character is not None and len(shared.history['visible']) == 1:
         if shared.args.cai_chat:
-            yield generate_chat_html(history['visible'], name1, name2, character)
+            yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
         else:
-            yield history['visible']
+            yield shared.history['visible']
     else:
-        last_visible = history['visible'].pop()
-        last_internal = history['internal'].pop()
+        last_visible = shared.history['visible'].pop()
+        last_internal = shared.history['internal'].pop()
 
         for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
             if shared.args.cai_chat:
-                history['visible'][-1] = [last_visible[0], _history[-1][1]]
-                yield generate_chat_html(history['visible'], name1, name2, character)
+                shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
+                yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
             else:
-                history['visible'][-1] = (last_visible[0], _history[-1][1])
-                yield history['visible']
+                shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
+                yield shared.history['visible']
 
 def remove_last_message(name1, name2):
-    if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
-        last = history['visible'].pop()
-        history['internal'].pop()
+    if not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
+        last = shared.history['visible'].pop()
+        shared.history['internal'].pop()
     else:
         last = ['', '']
     if shared.args.cai_chat:
-        return generate_chat_html(history['visible'], name1, name2, character), last[0]
+        return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
     else:
-        return history['visible'], last[0]
+        return shared.history['visible'], last[0]
 
 def send_last_reply_to_input():
-    if len(history['internal']) > 0:
-        return history['internal'][-1][1]
+    if len(shared.history['internal']) > 0:
+        return shared.history['internal'][-1][1]
     else:
         return ''
 
 def replace_last_reply(text, name1, name2):
-    if len(history['visible']) > 0:
+    if len(shared.history['visible']) > 0:
         if shared.args.cai_chat:
-            history['visible'][-1][1] = text
+            shared.history['visible'][-1][1] = text
         else:
-            history['visible'][-1] = (history['visible'][-1][0], text)
-        history['internal'][-1][1] = apply_extensions(text, "input")
+            shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
+        shared.history['internal'][-1][1] = apply_extensions(text, "input")
 
     if shared.args.cai_chat:
-        return generate_chat_html(history['visible'], name1, name2, character)
+        return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
     else:
-        return history['visible']
+        return shared.history['visible']
 
 def clear_html():
-    return generate_chat_html([], "", "", character)
-
-def clear_chat_log(_character, name1, name2):
-    global history
-    if _character != 'None':
-        for i in range(len(history['internal'])):
-            if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
-                history['visible'] = [['', history['internal'][i][1]]]
-                history['internal'] = history['internal'][:i+1]
+    return generate_chat_html([], "", "", shared.character)
+
+def clear_chat_log(name1, name2):
+    if shared.character != 'None':
+        for i in range(len(shared.history['internal'])):
+            if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
+                shared.history['visible'] = [['', shared.history['internal'][i][1]]]
+                shared.history['internal'] = shared.history['internal'][:i+1]
                 break
     else:
-        history['internal'] = []
-        history['visible'] = []
+        shared.history['internal'] = []
+        shared.history['visible'] = []
     if shared.args.cai_chat:
-        return generate_chat_html(history['visible'], name1, name2, character)
+        return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
     else:
-        return history['visible'] 
+        return shared.history['visible']
 
 def redraw_html(name1, name2):
-    global history
-    return generate_chat_html(history['visible'], name1, name2, character)
+    return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
 
 def tokenize_dialogue(dialogue, name1, name2):
     _history = []
@@ -273,47 +268,45 @@ def tokenize_dialogue(dialogue, name1, name2):
 
 def save_history(timestamp=True):
     if timestamp:
-        fname = f"{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
+        fname = f"{shared.character or ''}{'_' if shared.character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
     else:
-        fname = f"{character or ''}{'_' if character else ''}persistent.json"
+        fname = f"{shared.character or ''}{'_' if shared.character else ''}persistent.json"
     if not Path('logs').exists():
         Path('logs').mkdir()
     with open(Path(f'logs/{fname}'), 'w') as f:
-        f.write(json.dumps({'data': history['internal'], 'data_visible': history['visible']}, indent=2))
+        f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
     return Path(f'logs/{fname}')
 
 def load_history(file, name1, name2):
-    global history
     file = file.decode('utf-8')
     try:
         j = json.loads(file)
         if 'data' in j:
-            history['internal'] = j['data']
+            shared.history['internal'] = j['data']
             if 'data_visible' in j:
-                history['visible'] = j['data_visible']
+                shared.history['visible'] = j['data_visible']
             else:
-                history['visible'] = copy.deepcopy(history['internal'])
+                shared.history['visible'] = copy.deepcopy(shared.history['internal'])
         # Compatibility with Pygmalion AI's official web UI
         elif 'chat' in j:
-            history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
+            shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
             if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
-                history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
-                history['visible'] = copy.deepcopy(history['internal'])
-                history['visible'][0][0] = ''
+                shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
+                shared.history['visible'] = copy.deepcopy(shared.history['internal'])
+                shared.history['visible'][0][0] = ''
             else:
-                history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
-                history['visible'] = copy.deepcopy(history['internal'])
+                shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
+                shared.history['visible'] = copy.deepcopy(shared.history['internal'])
     except:
-        history['internal'] = tokenize_dialogue(file, name1, name2)
-        history['visible'] = copy.deepcopy(history['internal'])
+        shared.history['internal'] = tokenize_dialogue(file, name1, name2)
+        shared.history['visible'] = copy.deepcopy(shared.history['internal'])
 
 def load_character(_character, name1, name2):
-    global history, character
     context = ""
-    history['internal'] = []
-    history['visible'] = []
+    shared.history['internal'] = []
+    shared.history['visible'] = []
     if _character != 'None':
-        character = _character
+        shared.character = _character
         data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
         name2 = data['char_name']
         if 'char_persona' in data and data['char_persona'] != '':
@@ -322,25 +315,25 @@ def load_character(_character, name1, name2):
             context += f"Scenario: {data['world_scenario']}\n"
         context = f"{context.strip()}\n<START>\n"
         if 'example_dialogue' in data and data['example_dialogue'] != '':
-            history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
+            shared.history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
         if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
-            history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
-            history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
+            shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+            shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
         else:
-            history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
-            history['visible'] += [['', "Hello there!"]]
+            shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
+            shared.history['visible'] += [['', "Hello there!"]]
     else:
-        character = None
+        shared.character = None
         context = shared.settings['context_pygmalion']
         name2 = shared.settings['name2_pygmalion']
 
-    if Path(f'logs/{character}_persistent.json').exists():
-        load_history(open(Path(f'logs/{character}_persistent.json'), 'rb').read(), name1, name2)
+    if Path(f'logs/{shared.character}_persistent.json').exists():
+        load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
 
     if shared.args.cai_chat:
-        return name2, context, generate_chat_html(history['visible'], name1, name2, character)
+        return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
     else:
-        return name2, context, history['visible']
+        return name2, context, shared.history['visible']
 
 def upload_character(json_file, img, tavern=False):
     json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
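
The load_history compatibility branch above (the "Compatibility with Pygmalion AI's official web UI" block) is the densest logic in this file: a flat Pygmalion-style "chat" list of "Name: text" strings is stripped of its name prefixes and zipped into [user, bot] pairs, with the '<|BEGIN-VISIBLE-CHAT|>' sentinel inserted when the log starts with the bot. A small standalone walk-through of that pairing logic; the names and messages are made up for illustration:

# Illustrative input in the Pygmalion web UI export shape: {"chat": ["Name: text", ...]}
name1, name2 = 'You', 'Bot'
chat = [f'{name2}: Hello there!', f'{name1}: Hi, who are you?', f'{name2}: Just a bot.']

# Strip the "Name:" prefixes, as load_history does.
internal = [':'.join(x.split(':')[1:]).strip() for x in chat]

if len(chat) > 0 and chat[0].startswith(f'{name2}:'):
    # The log starts with the bot: mark its greeting with the sentinel, then pair the rest.
    internal = [['<|BEGIN-VISIBLE-CHAT|>', internal[0]]] + [[internal[i], internal[i+1]] for i in range(1, len(internal)-1, 2)]
else:
    internal = [[internal[i], internal[i+1]] for i in range(0, len(internal)-1, 2)]

print(internal)
# [['<|BEGIN-VISIBLE-CHAT|>', 'Hello there!'], ['Hi, who are you?', 'Just a bot.']]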

+ 4 - 0
modules/shared.py

@@ -7,6 +7,10 @@ soft_prompt_tensor = None
 soft_prompt = False
 stop_everything = False
 
+# Chat variables
+history = {'internal': [], 'visible': []}
+character = 'None'
+
 settings = {
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
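
One detail in the shared.py hunk above: character is initialised to the string 'None' (the label of the "None" entry in the character dropdown), not Python's None, which is the value clear_chat_log compares against. A short illustration of the distinction (purely illustrative, not part of the commit):

character = 'None'
print(character is None)    # False: this is a four-character string, not the None object
print(character != 'None')  # False: it equals the dropdown's "None" label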

+ 5 - 5
server.py

@@ -191,9 +191,9 @@ if shared.args.chat or shared.args.cai_chat:
 
     with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface:
         if shared.args.cai_chat:
-            display = gr.HTML(value=generate_chat_html(chat.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], chat.character))
+            display = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
         else:
-            display = gr.Chatbot(value=chat.history['visible'])
+            display = gr.Chatbot(value=shared.history['visible'])
         textbox = gr.Textbox(label='Input')
         with gr.Row():
             buttons["Stop"] = gr.Button("Stop")
@@ -272,7 +272,7 @@ if shared.args.chat or shared.args.cai_chat:
 
         buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream)
         buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream)
-        buttons["Clear history"].click(chat.clear_chat_log, [character_menu, name1, name2], display)
+        buttons["Clear history"].click(chat.clear_chat_log, [name1, name2], display)
         buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False)
         buttons["Download"].click(chat.save_history, inputs=[], outputs=[download])
         buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu])
@@ -295,8 +295,8 @@ if shared.args.chat or shared.args.cai_chat:
             upload_chat_history.upload(chat.redraw_html, [name1, name2], [display])
             upload_img_me.upload(chat.redraw_html, [name1, name2], [display])
         else:
-            upload_chat_history.upload(lambda : chat.history['visible'], [], [display])
-            upload_img_me.upload(lambda : chat.history['visible'], [], [display])
+            upload_chat_history.upload(lambda : shared.history['visible'], [], [display])
+            upload_img_me.upload(lambda : shared.history['visible'], [], [display])
 
 elif shared.args.notebook:
     with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: