|
|
@@ -99,6 +99,11 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|
|
return reply, next_character_found
|
|
|
|
|
|
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
|
|
+ if mode == 'instruct':
|
|
|
+ stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
|
|
+ else:
|
|
|
+ stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
|
|
+
|
|
|
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
|
|
name1_original = name1
|
|
|
if 'pygmalion' in shared.model_name.lower():
|
|
|
@@ -133,7 +138,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|
|
just_started = True
|
|
|
for i in range(generate_state['chat_generation_attempts']):
|
|
|
reply = None
|
|
|
- for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
|
|
+ for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
|
|
reply = cumulative_reply + reply
|
|
|
|
|
|
# Extracting the reply
|
|
|
@@ -163,6 +168,11 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|
|
yield shared.history['visible']
|
|
|
|
|
|
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
|
|
+ if mode == 'instruct':
|
|
|
+ stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
|
|
+ else:
|
|
|
+ stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
|
|
+
|
|
|
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
|
|
if 'pygmalion' in shared.model_name.lower():
|
|
|
name1 = "You"
|
|
|
@@ -175,7 +185,7 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
|
|
cumulative_reply = ''
|
|
|
for i in range(generate_state['chat_generation_attempts']):
|
|
|
reply = None
|
|
|
- for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
|
|
+ for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
|
|
reply = cumulative_reply + reply
|
|
|
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
|
|
yield reply
|