Просмотр исходного кода

Add also_return_rows to generate_chat_prompt

oobabooga 2 лет назад
Родитель
Сommit
fcda3f8776
1 измененных файлов с 6 добавлено и 2 удалено
  1. 6 2
      modules/chat.py

+ 6 - 2
modules/chat.py

@@ -22,7 +22,7 @@ def generate_chat_output(history, name1, name2, character):
     else:
     else:
         return history
         return history
 
 
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
     user_input = fix_newlines(user_input)
     user_input = fix_newlines(user_input)
     rows = [f"{context.strip()}\n"]
     rows = [f"{context.strip()}\n"]
 
 
@@ -51,7 +51,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
         rows.pop(1)
         rows.pop(1)
 
 
     prompt = ''.join(rows)
     prompt = ''.join(rows)
-    return prompt
+
+    if also_return_rows:
+        return prompt, rows
+    else:
+        return prompt
 
 
 def extract_message_from_reply(reply, name1, name2, check):
 def extract_message_from_reply(reply, name1, name2, check):
     next_character_found = False
     next_character_found = False