Преглед на файлове

Make the bos token optional

oobabooga преди 2 години
родител
ревизия
bd04ff27ad
променени са 3 файла, в които са добавени 12 реда и са изтрити 4 реда
  1. 1 0
      modules/shared.py
  2. 8 2
      modules/text_generation.py
  3. 3 2
      server.py

+ 1 - 0
modules/shared.py

@@ -35,6 +35,7 @@ settings = {
     'greeting': 'Hello there!',
     'end_of_turn': '',
     'stop_at_newline': False,
+    'add_bos_token': True,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
     'chat_prompt_size_max': 2048,

+ 8 - 2
modules/text_generation.py

@@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
     return max_length
 
 
-def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
+def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
     if any((shared.is_RWKV, shared.is_llamacpp)):
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
@@ -30,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
 
+        # This is a hack for making replies more creative.
+        if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
+            input_ids = input_ids[:, 1:]
+
+        # Llama adds this extra token when the first character is '\n', and this
+        # compromises the stopping criteria, so we just remove it
         if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
             input_ids = input_ids[:, 1:]
 
@@ -158,7 +164,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
             print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
             return
 
-    input_ids = encode(question, generate_state['max_new_tokens'])
+    input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
     original_input_ids = input_ids
     output = input_ids[0]
 

+ 3 - 2
server.py

@@ -233,7 +233,7 @@ def create_model_menus():
 
 def create_settings_menus(default_preset):
     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
-    for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
+    for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts', 'add_bos_token']:
         generate_params[k] = shared.settings[k]
     shared.gradio['generate_state'] = gr.State(generate_params)
 
@@ -273,6 +273,7 @@ def create_settings_menus(default_preset):
                     with gr.Column():
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+            shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
 
     with gr.Accordion('Soft prompt', open=False):
         with gr.Row():
@@ -610,7 +611,7 @@ def create_interface():
             d[key] = value
             return d
 
-        for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
+        for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
             if k not in shared.gradio:
                 continue
             if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]: