Instruction Character Vicuna, Instruction Mode Bugfix (#838)
This commit is contained in:
@@ -99,6 +99,11 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
return reply, next_character_found
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
name1_original = name1
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
@@ -133,7 +138,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
just_started = True
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extracting the reply
|
||||
@@ -163,6 +168,11 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
yield shared.history['visible']
|
||||
|
||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
@@ -175,7 +185,7 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
cumulative_reply = ''
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
reply = cumulative_reply + reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
yield reply
|
||||
|
||||
Reference in New Issue
Block a user