Stop generating in chat mode when \nYou: is generated

oobabooga, 3 years ago
commit df2e910421

1 changed file with 9 additions and 5 deletions

server.py (+9 −5)

@@ -143,7 +143,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
     cuda = ".cuda()" if args.cpu else ""
     for i in range(tokens):
-
         if eos_token is None:
             output = eval(f"model.generate(input_ids, {preset}){cuda}")
         else:
@@ -246,6 +245,7 @@ elif args.chat or args.cai_chat:
         eos_token = '\n' if check else None
         for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
             reply = i[0]
+            next_character_found = False

             if check:
                 idx = reply.rfind(question[-1024:])
@@ -256,6 +256,7 @@ elif args.chat or args.cai_chat:
                 idx = reply.find(f"\n{name1}:")
                 if idx != -1:
                     reply = reply[:idx]
+                    next_character_found = True
                 reply = clean_chat_message(reply)

             history[-1] = [text, reply]
@@ -263,14 +264,17 @@ elif args.chat or args.cai_chat:
             # Prevent the chat log from flashing if something like "\nYo" is generated just
             # before "\nYou:" is completed
             tmp = f"\n{name1}:"
-            found = False
-            for j in range(1, len(tmp)):
+            next_character_substring_found = False
+            for j in range(1, len(tmp)+1):
                 if reply[-j:] == tmp[:j]:
-                    found = True
+                    next_character_substring_found = True

-            if not found:
+            if not next_character_substring_found:
                 yield history

+            if next_character_found:
+                break
+
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
             yield generate_chat_html(history, name1, name2)
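
For reference, here is a minimal standalone sketch of the two checks this commit introduces. The names stream_chat_reply and token_stream are hypothetical, for illustration only; they are not part of server.py. The idea: once the full "\n{name1}:" marker appears, the reply is truncated and generation stops, and any partial tail such as "\nYo" holds back the yield so the chat log does not flash with a half-written marker.

# Hypothetical sketch of the commit's stopping logic; not the repository's code.
def stream_chat_reply(token_stream, name1="You"):
    stop_string = f"\n{name1}:"
    reply = ""
    for chunk in token_stream:  # assumed: an iterable of text chunks
        reply += chunk

        # If the full stop string has been generated, truncate the reply
        # at it and remember to stop after this iteration.
        next_character_found = False
        idx = reply.find(stop_string)
        if idx != -1:
            reply = reply[:idx]
            next_character_found = True

        # Hold back the yield while the tail of the reply could still grow
        # into the stop string (e.g. "\nYo" just before "\nYou:"). Note the
        # range goes up to len(stop_string)+1, the off-by-one this commit fixes.
        next_character_substring_found = False
        for j in range(1, len(stop_string) + 1):
            if reply[-j:] == stop_string[:j]:
                next_character_substring_found = True

        if not next_character_substring_found:
            yield reply
        if next_character_found:
            break

# Example: the partial "\nYo" is never shown; once "\nYou:" completes,
# the truncated reply is yielded once more and generation stops, so the
# trailing " ignored" chunk is never consumed.
for partial in stream_chat_reply(["Hi", " there", "\nYo", "u:", " ignored"]):
    print(repr(partial))
# 'Hi'
# 'Hi there'
# 'Hi there'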