Improve example dialogue handling

oobabooga 3 years ago
parent commit aadf4e899a
2 changed files with 17 additions and 21 deletions
  1. html_generator.py (+1 -8)
  2. server.py (+16 -13)

html_generator.py (+1 -8)

@@ -161,7 +161,7 @@ def generate_4chan_html(f):
 
     return output
 
-def generate_chat_html(_history, name1, name2, character):
+def generate_chat_html(history, name1, name2, character):
     css = """
     .chat {
       margin-left: auto;
@@ -234,13 +234,6 @@ def generate_chat_html(_history, name1, name2, character):
             img = f'<img src="file/{i}">'
             break
 
-    history = copy.deepcopy(_history)
-    for i in range(len(history)):
-        if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
-            history[i][0] = history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
-            history = history[i:]
-            break
-
     for i,_row in enumerate(history[::-1]):
         row = _row.copy()
         row[0] = re.sub(r"[\\]*\*", r"*", row[0])
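
generate_chat_html now takes the history directly instead of a private _history copy: the deep copy and the <|BEGIN-VISIBLE-CHAT|> handling move out of the renderer and into server.py, so the function simply renders whatever it is given. For reference, a made-up history showing the marker convention the filter relies on (the contents are illustrative, not from the repo):

    # A chat history is a list of [user_text, bot_text] pairs. Rows before the one
    # whose user side contains the marker are hidden example dialogue; the marker is
    # stripped from that first visible row before anything is rendered.
    history = [
        ['How do you talk?', 'Cheerfully!'],                        # hidden example
        ['<|BEGIN-VISIBLE-CHAT|>Hello!', 'Hi, nice to meet you!'],  # first visible turn
    ]
    # After this commit the caller filters first, then renders:
    #     generate_chat_html(remove_example_dialogue_from_history(history),
    #                        name1, name2, character)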

server.py (+16 -13)

@@ -98,7 +98,7 @@ def load_model(model_name):
                     settings.append(f"max_memory={{0: '{args.gpu_memory}GiB', 'cpu': '99GiB'}}")
             if args.disk:
                 if args.disk_cache_dir is not None:
-                    settings.append("offload_folder='"+args.disk_cache_dir+"'")
+                    settings.append(f"offload_folder='{args.disk_cache_dir}'")
                 else:
                     settings.append("offload_folder='cache'")
             if args.load_in_8bit:
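
The change in this hunk is purely cosmetic: string concatenation becomes an f-string, matching the max_memory line just above it. A quick sanity check that the two forms build the identical string (the path is a stand-in value, not a repo default):

    disk_cache_dir = '/mnt/llm-cache'   # stand-in for args.disk_cache_dir
    old = "offload_folder='" + disk_cache_dir + "'"
    new = f"offload_folder='{disk_cache_dir}'"
    assert old == new
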
@@ -265,6 +265,15 @@ if args.chat or args.cai_chat:
         question = question.replace('<|BEGIN-VISIBLE-CHAT|>', '')
         return question
 
+    def remove_example_dialogue_from_history(history):
+        _history = copy.deepcopy(history)
+        for i in range(len(_history)):
+            if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
+                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
+                _history = _history[i:]
+                break
+        return _history
+
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
         history.append(['', ''])
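
The logic deleted from html_generator.py resurfaces here as a named helper: deep-copy the history, find the first row whose user side contains <|BEGIN-VISIBLE-CHAT|>, strip the marker, drop every earlier row, and return the result; if no marker is found, the untouched copy comes back. A standalone sketch of its behavior on a made-up history (the function body is copied from the hunk above so the snippet runs on its own):

    import copy

    def remove_example_dialogue_from_history(history):
        _history = copy.deepcopy(history)
        for i in range(len(_history)):
            if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
                _history = _history[i:]
                break
        return _history

    history = [
        ['Example question', 'Example answer'],
        ['<|BEGIN-VISIBLE-CHAT|>Hello!', 'Hi there!'],
    ]
    print(remove_example_dialogue_from_history(history))
    # [['Hello!', 'Hi there!']] -- example rows dropped, marker stripped
    print(history[0])
    # ['Example question', 'Example answer'] -- input untouched thanks to the deepcopy
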
@@ -300,9 +309,9 @@ if args.chat or args.cai_chat:
                     next_character_substring_found = True
 
             if not next_character_substring_found:
-                yield history
+                yield remove_example_dialogue_from_history(history)
 
-        yield history
+        yield remove_example_dialogue_from_history(history)
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
         for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
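
chatbot_wrapper mutates the stored history in place as the reply streams in, so each intermediate yield (and the final one) now passes through the filter: the UI never sees the marker or the example rows, even mid-generation. A self-contained toy of that pattern, with strip_marker standing in for remove_example_dialogue_from_history:

    MARKER = '<|BEGIN-VISIBLE-CHAT|>'

    def strip_marker(history):
        # toy-sized stand-in: strips the marker but keeps all rows
        return [[u.replace(MARKER, ''), b] for u, b in history]

    def toy_chatbot_wrapper(history, reply_tokens):
        history.append(['Hello!', ''])
        for tok in reply_tokens:
            history[-1][1] += tok
            yield strip_marker(history)   # every partial snapshot is pre-filtered

    raw = [[MARKER + 'first visible line', 'ok']]
    for snapshot in toy_chatbot_wrapper(raw, ['Hi', ' there']):
        print(snapshot[-1])               # ['Hello!', 'Hi'], then ['Hello!', 'Hi there']
    print(MARKER in raw[0][0])            # True: the stored history keeps the marker
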
@@ -327,17 +336,10 @@ if args.chat or args.cai_chat:
         return generate_chat_html(history, name1, name2, character)
 
     def save_history():
-        _history = copy.deepcopy(history)
-        for i in range(len(_history)):
-            if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
-                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
-                _history = _history[i:]
-                break
-
         if not Path('logs').exists():
             Path('logs').mkdir()
         with open(Path('logs/conversation.json'), 'w') as f:
-            f.write(json.dumps({'data': _history}))
+            f.write(json.dumps({'data': history}))
         return Path('logs/conversation.json')
 
     def load_history(file):
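
save_history previously filtered the example dialogue out before writing, so a saved-then-reloaded conversation lost its marker for good. Dumping the raw history keeps the marker inside logs/conversation.json, and filtering stays a display-time concern. A minimal round-trip sketch with a toy history:

    import json
    from pathlib import Path

    history = [['<|BEGIN-VISIBLE-CHAT|>Hello!', 'Hi there!']]

    Path('logs').mkdir(exist_ok=True)
    with open(Path('logs/conversation.json'), 'w') as f:
        f.write(json.dumps({'data': history}))

    with open(Path('logs/conversation.json')) as f:
        restored = json.loads(f.read())['data']
    assert restored == history   # marker intact, so the example dialogue survives a reload
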
@@ -389,10 +391,11 @@ if args.chat or args.cai_chat:
             context = settings['context_pygmalion']
             name2 = settings['name2_pygmalion']
 
+        _history = remove_example_dialogue_from_history(history)
         if args.cai_chat:
-            return name2, context, generate_chat_html(history, name1, name2, character)
+            return name2, context, generate_chat_html(_history, name1, name2, character)
         else:
-            return name2, context, history
+            return name2, context, _history
 
     suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
     with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
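
The character-loading path gets the same treatment: the raw history, marker included, stays the stored state, and the filter is applied only to what the interface receives, whether that is the cai_chat HTML or the plain pair list. A runnable toy of that final shape, with small stubs standing in for the real functions:

    MARKER = '<|BEGIN-VISIBLE-CHAT|>'
    history = [['Example Q', 'Example A'], [MARKER + 'Hey!', 'Hello!']]

    def remove_example_dialogue_from_history(h):          # stub with the same behavior
        i = next(i for i, row in enumerate(h) if MARKER in row[0])
        return [[h[i][0].replace(MARKER, ''), h[i][1]], *h[i + 1:]]

    def generate_chat_html(h, name1, name2, character):   # stub renderer
        return ''.join(f'<div>{u}: {b}</div>' for u, b in h)

    cai_chat = True                                       # stand-in for args.cai_chat
    _history = remove_example_dialogue_from_history(history)
    ui_value = generate_chat_html(_history, 'You', 'Bot', 'Example') if cai_chat else _history
    print(ui_value)                  # <div>Hey!: Hello!</div>
    print(MARKER in history[1][0])   # True: the stored state is unchanged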