Explorar el Código

Fix chat_generation_attempts

oobabooga hace 2 años
padre
commit
4f5c2ce785
Se han modificado 1 ficheros con 13 adiciones y 5 borrados
  1. 13 5
      modules/chat.py

+ 13 - 5
modules/chat.py

@@ -115,9 +115,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         yield shared.history['visible']+[[visible_text, shared.processing_message]]
 
     # Generate
-    reply = ''
+    cumulative_reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+            reply = cumulative_reply + reply
 
             # Extracting the reply
             reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
@@ -142,6 +143,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
             if next_character_found:
                 break
 
+        cumulative_reply = reply
+
     yield shared.history['visible']
 
 def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
@@ -152,16 +155,21 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
     prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
 
-    reply = ''
     # Yield *Is typing...*
     yield shared.processing_message
+
+    cumulative_reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+            reply = cumulative_reply + reply
             reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
             yield reply
             if next_character_found:
                 break
-        yield reply
+
+        cumulative_reply = reply
+
+    yield reply
 
 def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):