Prechádzať zdrojové kódy

Make chat history persistent

oobabooga 3 rokov pred
rodič
commit
b397bea387
1 zmenil súbory, kde vykonal 15 pridanie a 3 odobranie
  1. 15 3
      server.py

+ 15 - 3
server.py

@@ -701,8 +701,11 @@ def tokenize_dialogue(dialogue, name1, name2):
 
     return _history
 
-def save_history():
-    fname = f"{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
+def save_history(timestamp=True):
+    if timestamp:
+        fname = f"{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
+    else:
+        fname = f"{character or ''}{'_' if character else ''}persistent.json"
     if not Path('logs').exists():
         Path('logs').mkdir()
     with open(Path(f'logs/{fname}'), 'w') as f:
@@ -761,6 +764,9 @@ def load_character(_character, name1, name2):
         context = settings['context_pygmalion']
         name2 = settings['name2_pygmalion']
 
+    if Path(f'logs/{character}_persistent.json').exists():
+        load_history(open(Path(f'logs/{character}_persistent.json'), 'rb').read(), name1, name2)
+
     if args.cai_chat:
         return name2, context, generate_chat_html(history['visible'], name1, name2, character)
     else:
@@ -859,9 +865,13 @@ history = {'internal': [], 'visible': []}
 character = None
 
 if args.chat or args.cai_chat:
+
+    if Path(f'logs/persistent.json').exists():
+        load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}'])
+
     with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto} .w-screen {width: unset}", analytics_enabled=False) as interface:
         if args.cai_chat:
-            display = gr.HTML(value=generate_chat_html([], "", "", character))
+            display = gr.HTML(value=generate_chat_html(history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], character))
         else:
             display = gr.Chatbot()
         textbox = gr.Textbox(label='Input')
@@ -949,8 +959,10 @@ if args.chat or args.cai_chat:
         buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu])
         for i in ["Generate", "Regenerate", "Replace last reply"]:
             buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
+            buttons[i].click(lambda : save_history(timestamp=False), [], [], show_progress=False)
 
         textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
+        textbox.submit(lambda : save_history(timestamp=False), [], [], show_progress=False)
         character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display])
         upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
         upload.upload(load_history, [upload, name1, name2], [])