Instruction Character Vicuna, Instruction Mode Bugfix (#838)

OWKenobi, 2 years ago
Parent commit: 310bf46a94
2 changed files with 15 additions and 2 deletions:
  1. characters/instruction-following/Vicuna.yaml (+3, -0)
  2. modules/chat.py (+12, -2)

characters/instruction-following/Vicuna.yaml (+3, -0)

@@ -0,0 +1,3 @@
+name: "### Assistant:"
+your_name: "### Human:"
+context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
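
The three fields above define the Vicuna turn format: name is the bot's turn prefix, your_name is the user's, and context is the system preamble. Below is a minimal sketch of how these fields plausibly compose a prompt, with the values inlined so it runs standalone; the actual prompt assembly lives elsewhere in modules/chat.py and is not shown in this diff.

# The three Vicuna.yaml fields, inlined so the sketch is self-contained.
char = {
    "name": "### Assistant:",
    "your_name": "### Human:",
    "context": "Below is an instruction that describes a task. "
               "Write a response that appropriately completes the request.",
}

def build_prompt(user_message, history=()):
    # context, then alternating Human/Assistant turns, ending with the
    # bot prefix so the model continues as "### Assistant:".
    lines = [char["context"], ""]
    for user_turn, bot_turn in history:
        lines.append(f"{char['your_name']} {user_turn}")
        lines.append(f"{char['name']} {bot_turn}")
    lines.append(f"{char['your_name']} {user_message}")
    lines.append(char["name"])
    return "\n".join(lines)

print(build_prompt("Summarize the plot of Hamlet in one sentence."))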

modules/chat.py (+12, -2)

@@ -99,6 +99,11 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     return reply, next_character_found
 
 def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+    if mode == 'instruct':
+        stopping_strings = [f"\n{name1}", f"\n{name2}"]
+    else:
+        stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
+        
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     name1_original = name1
     if 'pygmalion' in shared.model_name.lower():
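
The bug being fixed: in instruct mode, name1 and name2 are taken from the character file, so for Vicuna they are already "### Human:" and "### Assistant:", colon included. The old code appended a second colon, producing stop strings that never occur in the model's output, so generation ran straight past the turn boundary. A self-contained illustration:

# In instruct mode the names already end with a colon (see Vicuna.yaml).
name1, name2 = "### Human:", "### Assistant:"

old_stops = [f"\n{name1}:", f"\n{name2}:"]  # '\n### Human::' -- doubled colon
new_stops = [f"\n{name1}", f"\n{name2}"]    # '\n### Human:'  -- as emitted

reply = "Paris is the capital of France.\n### Human: And Spain?"
print(any(s in reply for s in old_stops))  # False: the old strings never match
print(any(s in reply for s in new_stops))  # True: the reply stops at the turn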
@@ -133,7 +138,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     just_started = True
     for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
             reply = cumulative_reply + reply
 
             # Extracting the reply
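
For context, here is a rough sketch of the truncation contract the wrappers appear to rely on: generate_reply cuts the stream at the first stopping string it encounters. The real function streams tokens from the model and is not part of this diff; the helper below is only an assumption about its post-processing.

def apply_stopping_strings(text, stopping_strings):
    # Truncate at the earliest occurrence of any stopping string
    # (hypothetical helper; the real logic lives inside generate_reply).
    cut = len(text)
    for s in stopping_strings:
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

print(apply_stopping_strings(
    "Paris is the capital of France.\n### Human: And Spain?",
    ["\n### Human:", "\n### Assistant:"],
))  # -> 'Paris is the capital of France.'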
@@ -163,6 +168,11 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     yield shared.history['visible']
 
 def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    if mode == 'instruct':
+        stopping_strings = [f"\n{name1}", f"\n{name2}"]
+    else:
+        stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
+        
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
@@ -175,7 +185,7 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
     cumulative_reply = ''
     for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
             reply = cumulative_reply + reply
             reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             yield reply
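
Design note: the four-line mode check is now duplicated verbatim in chatbot_wrapper and impersonate_wrapper. A follow-up could hoist it into a shared helper; a hypothetical sketch, not part of this commit:

def get_stopping_strings(mode, name1, name2):
    # Instruct characters embed the trailing colon in the name itself,
    # so no extra colon is appended in that mode.
    if mode == 'instruct':
        return [f"\n{name1}", f"\n{name2}"]
    return [f"\n{name1}:", f"\n{name2}:"]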