Jelajahi Sumber

Don't regenerate if no message has been sent

oobabooga 3 tahun lalu
induk
melakukan
589069e105
1 mengubah file dengan 12 tambahan dan 6 penghapusan
  1. 12 6
      server.py

+ 12 - 6
server.py

@@ -612,16 +612,22 @@ def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
         yield generate_chat_html(_history, name1, name2, character)
 
 def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
-    last_visible = history['visible'].pop()
-    last_internal = history['internal'].pop()
-
-    for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+    if character is not None and len(history['visible']) == 1:
         if args.cai_chat:
-            history['visible'][-1] = [last_visible[0], _history[-1][1]]
             yield generate_chat_html(history['visible'], name1, name2, character)
         else:
-            history['visible'][-1] = (last_visible[0], _history[-1][1])
             yield history['visible']
+    else:
+        last_visible = history['visible'].pop()
+        last_internal = history['internal'].pop()
+
+        for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+            if args.cai_chat:
+                history['visible'][-1] = [last_visible[0], _history[-1][1]]
+                yield generate_chat_html(history['visible'], name1, name2, character)
+            else:
+                history['visible'][-1] = (last_visible[0], _history[-1][1])
+                yield history['visible']
 
 def remove_last_message(name1, name2):
     if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':