|
@@ -390,7 +390,15 @@ if args.chat or args.cai_chat:
|
|
|
next_character_found = True
|
|
next_character_found = True
|
|
|
reply = clean_chat_message(reply)
|
|
reply = clean_chat_message(reply)
|
|
|
|
|
|
|
|
- return reply, next_character_found
|
|
|
|
|
|
|
+ # Detect if something like "\nYo" is generated just before
|
|
|
|
|
+ # "\nYou:" is completed
|
|
|
|
|
+ tmp = f"\n{other}:"
|
|
|
|
|
+ substring_found = False
|
|
|
|
|
+ for j in range(1, len(tmp)):
|
|
|
|
|
+ if reply[-j:] == tmp[:j]:
|
|
|
|
|
+ substring_found = True
|
|
|
|
|
+
|
|
|
|
|
+ return reply, next_character_found, substring_found
|
|
|
|
|
|
|
|
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
|
original_text = text
|
|
original_text = text
|
|
@@ -400,21 +408,25 @@ if args.chat or args.cai_chat:
|
|
|
history['visible'].append(['', ''])
|
|
history['visible'].append(['', ''])
|
|
|
eos_token = '\n' if check else None
|
|
eos_token = '\n' if check else None
|
|
|
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
|
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
|
|
- reply, next_character_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
|
|
|
|
|
|
|
+ reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
|
|
|
history['internal'][-1] = [text, reply]
|
|
history['internal'][-1] = [text, reply]
|
|
|
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
|
|
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
|
|
|
- yield history['visible']
|
|
|
|
|
|
|
+ if not substring_found:
|
|
|
|
|
+ yield history['visible']
|
|
|
if next_character_found:
|
|
if next_character_found:
|
|
|
break
|
|
break
|
|
|
|
|
+ yield history['visible']
|
|
|
|
|
|
|
|
def impersonate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
def impersonate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
|
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
|
|
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
|
|
|
eos_token = '\n' if check else None
|
|
eos_token = '\n' if check else None
|
|
|
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
|
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
|
|
- reply, next_character_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
|
|
|
|
|
- yield apply_extensions(reply, "output")
|
|
|
|
|
|
|
+ reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
|
|
|
|
|
+ if not substring_found:
|
|
|
|
|
+ yield apply_extensions(reply, "output")
|
|
|
if next_character_found:
|
|
if next_character_found:
|
|
|
break
|
|
break
|
|
|
|
|
+ yield apply_extensions(reply, "output")
|
|
|
|
|
|
|
|
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
|
for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
|
for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|