Kaynağa Gözat

Implement regenerate/impersonate the proper way (fixes #78)

oobabooga 3 yıl önce
ebeveyn
işleme
b3bcd2881d
1 değiştirilmiş dosya ile 11 ekleme ve 13 silme
  1. 11 13
      server.py

+ 11 - 13
server.py

@@ -599,28 +599,26 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
     for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
         reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
         if not substring_found:
-            yield apply_extensions(reply, "output")
+            yield reply
         if next_character_found:
             break
-    yield apply_extensions(reply, "output")
+    yield reply
 
 def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
     for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
         yield generate_chat_html(_history, name1, name2, character)
 
 def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
-    last = history['visible'].pop()
+    last_visible = history['visible'].pop()
+    last_internal = history['internal'].pop()
 
-    # Fix for when the last sent message was an image
-    if last[0].startswith('<img src="'):
-        last[0] = history['internal'].pop()[0]
-    else:
-        history['internal'].pop()
-
-    text = last[0]
-    function_call = "cai_chatbot_wrapper" if args.cai_chat else "chatbot_wrapper"
-    for i in eval(function_call)(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
-        yield i
+    for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+        if args.cai_chat:
+            history['visible'][-1] = [last_visible[0], _history[-1][1]]
+            yield generate_chat_html(history['visible'], name1, name2, character)
+        else:
+            history['visible'][-1] = (last_visible[0], _history[-1][1])
+            yield history['visible']
 
 def remove_last_message(name1, name2):
     if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':