소스 검색

Add a history size parameter to the chat

If too many messages are used in the prompt, the model
gets really slow. It is useful to have the ability to
limit this.
oobabooga 3 년 전
부모
커밋
185587a33e
2개의 변경된 파일24개의 추가작업 그리고 10개의 파일을 삭제
  1. 21 10
      server.py
  2. 3 0
      settings-template.json

+ 21 - 10
server.py

@@ -49,6 +49,9 @@ settings = {
     'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
     'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
     'stop_at_newline': True,
+    'history_size': 8,
+    'history_size_min': 0,
+    'history_size_max': 64,
     'preset_pygmalion': 'Pygmalion',
     'name1_pygmalion': 'You',
     'name2_pygmalion': 'Kawaii',
@@ -229,16 +232,21 @@ if args.chat or args.cai_chat:
         text = text.strip()
         return text
 
-    def generate_chat_prompt(text, tokens, name1, name2, context):
+    def generate_chat_prompt(text, tokens, name1, name2, context, history_size):
         text = clean_chat_message(text)
 
         rows = [f"{context.strip()}\n"]
         i = len(history)-1
+        count = 0
         while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
             rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            count += 1
             if not (i == 0 and len(history[i][0]) == 0):
                 rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+                count += 1
             i -= 1
+            if history_size != 0 and count >= history_size:
+                break
         rows.append(f"{name1}: {text}\n")
         rows.append(f"{name2}:")
 
@@ -247,10 +255,11 @@ if args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        print(question)
         return question
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        question = generate_chat_prompt(text, tokens, name1, name2, context)
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
         history.append(['', ''])
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
@@ -288,8 +297,8 @@ if args.chat or args.cai_chat:
 
         yield history
 
-    def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+    def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
             yield generate_chat_html(history, name1, name2, character)
 
     def remove_last_message(name1, name2):
@@ -362,11 +371,12 @@ if args.chat or args.cai_chat:
             stop = gr.Button("Stop")
             btn3 = gr.Button("Remove last message")
 
-        length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
         with gr.Row():
             with gr.Column():
+                length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                 model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
             with gr.Column():
+                history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
 
         name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
@@ -385,13 +395,14 @@ if args.chat or args.cai_chat:
                 save_btn = gr.Button(value="Click me")
                 download = gr.File()
 
+        input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
         if args.cai_chat:
-            gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
-            gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
+            gen_event = btn.click(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
+            gen_event2 = textbox.submit(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
             btn2.click(clear_html, [], display1, show_progress=False)
         else:
-            gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
-            gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
+            gen_event = btn.click(chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
+            gen_event2 = textbox.submit(chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
             btn2.click(lambda x: "", display1, display1, show_progress=False)
 
         btn2.click(clear)

+ 3 - 0
settings-template.json

@@ -9,6 +9,9 @@
     "prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
     "prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n",
     "stop_at_newline": true,
+    "history_size": 8,
+    "history_size_min": 0,
+    "history_size_max": 64,
     "preset_pygmalion": "Pygmalion",
     "name1_pygmalion": "You",
     "name2_pygmalion": "Kawaii",