Browse Source

Add 'hallucinations' filter #326

This breaks the API since a new parameter has been added.
It should be a one-line fix. See api-example.py.
oobabooga 3 years ago
parent
commit
9d6a625bd6
5 changed files with 25 additions and 18 deletions
  1. 2 0
      api-example-stream.py
  2. 2 0
      api-example.py
  3. 8 8
      modules/chat.py
  4. 2 1
      modules/text_generation.py
  5. 11 9
      server.py

+ 2 - 0
api-example-stream.py

@@ -26,6 +26,7 @@ async def run(context):
         'top_p': 0.9,
         'typical_p': 1,
         'repetition_penalty': 1.05,
+        'encoder_repetition_penalty': 1.0,
         'top_k': 0,
         'min_length': 0,
         'no_repeat_ngram_size': 0,
@@ -59,6 +60,7 @@ async def run(context):
                             params['top_p'],
                             params['typical_p'],
                             params['repetition_penalty'],
+                            params['encoder_repetition_penalty'],
                             params['top_k'],
                             params['min_length'],
                             params['no_repeat_ngram_size'],

+ 2 - 0
api-example.py

@@ -24,6 +24,7 @@ params = {
     'top_p': 0.9,
     'typical_p': 1,
     'repetition_penalty': 1.05,
+    'encoder_repetition_penalty': 1.0,
     'top_k': 0,
     'min_length': 0,
     'no_repeat_ngram_size': 0,
@@ -45,6 +46,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
         params['top_p'],
         params['typical_p'],
         params['repetition_penalty'],
+        params['encoder_repetition_penalty'],
         params['top_k'],
         params['min_length'],
         params['no_repeat_ngram_size'],

+ 8 - 8
modules/chat.py

@@ -97,7 +97,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
 def stop_everything_event():
     shared.stop_everything = True
 
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
     shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
@@ -133,7 +133,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     # Generate
     reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
 
             # Extracting the reply
             reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
@@ -160,7 +160,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     yield shared.history['visible']
 
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     eos_token = '\n' if check else None
 
     if 'pygmalion' in shared.model_name.lower():
@@ -172,18 +172,18 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     # Yield *Is typing...*
     yield shared.processing_message
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
             reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
             yield reply
             if next_character_found:
                 break
         yield reply
 
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
-    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
         yield generate_chat_html(_history, name1, name2, shared.character)
 
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
         yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
     else:
@@ -191,7 +191,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
         last_internal = shared.history['internal'].pop()
         # Yield '*Is typing...*'
         yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
-        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
             if shared.args.cai_chat:
                 shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
             else:

+ 2 - 1
modules/text_generation.py

@@ -89,7 +89,7 @@ def clear_torch_cache():
     if not shared.args.cpu:
         torch.cuda.empty_cache()
 
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
     clear_torch_cache()
     t0 = time.time()
 
@@ -143,6 +143,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             "top_p": top_p,
             "typical_p": typical_p,
             "repetition_penalty": repetition_penalty,
+            "encoder_repetition_penalty": encoder_repetition_penalty,
             "top_k": top_k,
             "min_length": min_length if shared.args.no_stream else 0,
             "no_repeat_ngram_size": no_repeat_ngram_size,

+ 11 - 9
server.py

@@ -66,6 +66,7 @@ def load_preset_values(preset_menu, return_dict=False):
         'top_p': 1,
         'typical_p': 1,
         'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
         'top_k': 50,
         'num_beams': 1,
         'penalty_alpha': 0,
@@ -86,7 +87,7 @@ def load_preset_values(preset_menu, return_dict=False):
     if return_dict:
         return generate_params
     else:
-        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
 
 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -117,14 +118,15 @@ def create_settings_menus(default_preset):
         with gr.Row():
             with gr.Column():
                 shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
-                shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
-                shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
                 shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
-            with gr.Column():
-                shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+                shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
                 shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+            with gr.Column():
+                shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
+                shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
                 shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
                 shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
+        shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
 
         gr.Markdown('Contrastive search:')
         shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
@@ -147,7 +149,7 @@ def create_settings_menus(default_preset):
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
 
@@ -262,7 +264,7 @@ if shared.args.chat or shared.args.cai_chat:
                 shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
             create_settings_menus(default_preset)
 
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
         if shared.args.extensions is not None:
             with gr.Tab('Extensions'):
                 extensions_module.create_extensions_block()
@@ -329,7 +331,7 @@ elif shared.args.notebook:
         if shared.args.extensions is not None:
             extensions_module.create_extensions_block()
 
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
         gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
@@ -361,7 +363,7 @@ else:
                 with gr.Tab('HTML'):
                     shared.gradio['html'] = gr.HTML()
 
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
         gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
         gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))