oobabooga 3 lat temu
rodzic
commit
161cae001b
1 zmienionych plików z 17 dodań i 5 usunięć
  1. 17 5
      server.py

+ 17 - 5
server.py

@@ -390,7 +390,15 @@ if args.chat or args.cai_chat:
                 next_character_found = True
             reply = clean_chat_message(reply)
 
-        return reply, next_character_found
+            # Detect if something like "\nYo" is generated just before
+            # "\nYou:" is completed
+            tmp = f"\n{other}:"
+            substring_found = False
+            for j in range(1, len(tmp)):
+                if reply[-j:] == tmp[:j]:
+                    substring_found = True
+
+        return reply, next_character_found, substring_found
 
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
         original_text = text
@@ -400,21 +408,25 @@ if args.chat or args.cai_chat:
         history['visible'].append(['', ''])
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
-            reply, next_character_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
+            reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
             history['internal'][-1] = [text, reply]
             history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
-            yield history['visible']
+            if not substring_found:
+                yield history['visible']
             if next_character_found:
                 break
+        yield history['visible']
 
     def impersonate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name2}:"):
-            reply, next_character_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
-            yield apply_extensions(reply, "output")
+            reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
+            if not substring_found:
+                yield apply_extensions(reply, "output")
             if next_character_found:
                 break
+        yield apply_extensions(reply, "output")
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
         for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):