Bladeren bron

Set chat prompt size in tokens

oobabooga 3 jaren geleden
bovenliggende
commit
7be372829d
2 gewijzigde bestanden met toevoegingen van 19 en 21 verwijderingen
  1. 16 18
      server.py
  2. 3 3
      settings-template.json

+ 16 - 18
server.py

@@ -71,9 +71,9 @@ settings = {
     'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
     'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
     'stop_at_newline': True,
-    'history_size': 0,
-    'history_size_min': 0,
-    'history_size_max': 64,
+    'chat_prompt_size': 2048,
+    'chat_prompt_size_min': 0,
+    'chat_prompt_size_max': 2048,
     'preset_pygmalion': 'Pygmalion',
     'name1_pygmalion': 'You',
     'name2_pygmalion': 'Kawaii',
@@ -503,13 +503,13 @@ def clean_chat_message(text):
     text = text.strip()
     return text
 
-def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=False):
+def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
     text = clean_chat_message(text)
 
     rows = [f"{context.strip()}\n"]
     i = len(history['internal'])-1
     count = 0
-    max_length = get_max_prompt_length(tokens)
+    max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
     while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         count += 1
@@ -517,8 +517,6 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
             rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
             count += 1
         i -= 1
-        if history_size != 0 and count >= history_size:
-            break
 
     if not impersonate:
         rows.append(f"{name1}: {text}\n")
@@ -566,14 +564,14 @@ def extract_message_from_reply(question, reply, current, other, check, extension
 
     return reply, next_character_found, substring_found
 
-def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
+def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
     if args.picture and picture is not None:
         text, visible_text = generate_chat_picture(picture, name1, name2)
     else:
         visible_text = text
 
     text = apply_extensions(text, "input")
-    question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
+    question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size)
     history['internal'].append(['', ''])
     history['visible'].append(['', ''])
     eos_token = '\n' if check else None
@@ -587,8 +585,8 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
             break
     yield history['visible']
 
-def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
-    question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
+def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+    question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True)
     eos_token = '\n' if check else None
     for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
         reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
@@ -598,19 +596,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
             break
     yield apply_extensions(reply, "output")
 
-def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
-    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
         yield generate_chat_html(_history, name1, name2, character)
 
-def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
+def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
     last = history['visible'].pop()
     history['internal'].pop()
     text = last[0]
     if args.cai_chat:
-        for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+        for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
             yield i
     else:
-        for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+        for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
             yield i
 
 def remove_last_message(name1, name2):
@@ -886,7 +884,7 @@ if args.chat or args.cai_chat:
             with gr.Column():
                 max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
             with gr.Column():
-                history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size'])
+                chat_prompt_size_slider = gr.Slider(minimum=settings['chat_prompt_size_min'], maximum=settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=settings['chat_prompt_size'])
 
         preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
 
@@ -926,7 +924,7 @@ if args.chat or args.cai_chat:
         if args.extensions is not None:
             create_extensions_block()
 
-        input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
+        input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
         if args.picture:
             input_params.append(picture_select)
         if args.cai_chat:

+ 3 - 3
settings-template.json

@@ -9,9 +9,9 @@
     "prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
     "prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n",
     "stop_at_newline": true,
-    "history_size": 0,
-    "history_size_min": 0,
-    "history_size_max": 64,
+    "chat_prompt_size": 2048,
+    "chat_prompt_size_min": 0,
+    "chat_prompt_size_max": 2048,
     "preset_pygmalion": "Pygmalion",
     "name1_pygmalion": "You",
     "name2_pygmalion": "Kawaii",