Browse source

Stop generating at \nYou: in chat mode

oobabooga 3 years ago
commit 3b8f0021cc
2 changed files with 20 additions and 9 deletions
  1. README.md (+1 -1)
  2. server.py (+19 -8)

README.md (+1 -1)

@@ -153,6 +153,6 @@ Pull requests, suggestions, and issue reports are welcome.
 ## Credits
 
 - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
-- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
+- Pygmalion preset, code for early stopping in chat mode: https://github.com/PygmalionAI/gradio-ui/
 - Verbose preset: Anonymous 4chan user.
 - Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui

server.py (+19 -8)

@@ -14,6 +14,7 @@ import transformers
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from modules.html_generator import *
 from modules.ui import *
+from modules.stopping_criteria import _SentinelTokenStoppingCriteria
 
 transformers.logging.set_verbosity_error()
 
@@ -135,12 +136,12 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
-def encode(prompt, tokens):
+def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     if args.cpu:
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
     else:
         torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
     return input_ids
 
 def decode(output_ids):
@@ -161,7 +162,7 @@ def formatted_outputs(reply, model_name):
     else:
         return reply
 
-def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
+def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
     if selected_model != model_name:
@@ -179,11 +180,22 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     cuda = "" if args.cpu else ".cuda()"
     n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
     input_ids = encode(question, tokens)
+    # The stopping_criteria code below was copied from
+    # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+    if stopping_string is not None:
+        t = encode(stopping_string, 0, add_special_tokens=False)
+        stopping_criteria_list = transformers.StoppingCriteriaList([
+            _SentinelTokenStoppingCriteria(
+            sentinel_token_ids=t,
+            starting_idx=len(input_ids[0]))
+        ])
+    else:
+        stopping_criteria_list = None
 
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+        output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
         reply = decode(output[0])
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
@@ -194,11 +206,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         yield formatted_outputs(question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
         for i in tqdm(range(tokens)):
-            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
             if eos_token is not None and reply[-1] == eos_token:
                 break
-
             yield formatted_outputs(reply, model_name)
             input_ids = output
 
@@ -289,7 +300,7 @@ if args.chat or args.cai_chat:
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
         history.append(['', ''])
         eos_token = '\n' if check else None
-        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
             next_character_found = False
 
             previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]
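The diff imports _SentinelTokenStoppingCriteria from modules/stopping_criteria, a file that is not part of this commit. For reference, here is a minimal sketch of such a criterion, modeled on the PygmalionAI gradio-ui code credited in the comment above; the actual modules/stopping_criteria.py may differ in its details:

import torch
import transformers

class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self, sentinel_token_ids: torch.LongTensor, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids  # shape (1, n), as returned by encode()
        self.starting_idx = starting_idx  # prompt length; only generated tokens are checked

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed = sample[self.starting_idx:]  # drop the prompt tokens
            # Too few generated tokens yet to contain the sentinel sequence
            if trimmed.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue
            # Slide a window over the generated tokens and compare against the sentinel
            for window in trimmed.unfold(0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False

With stopping_string=f"\n{name1}:" passed from the chat handler, encode(stopping_string, 0, add_special_tokens=False) produces the sentinel token ids, so generation halts as soon as the model begins writing the user's next chat line instead of continuing the conversation on the user's behalf.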