
Prevent *Is typing* from disappearing instantly while streaming

oobabooga 2 years ago
Parent
Commit
cf2da86352
1 changed file with 4 additions and 2 deletions
  1. +4 −2 modules/text_generation.py

+ 4 - 2
modules/text_generation.py

@@ -101,7 +101,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 yield formatted_outputs(reply, shared.model_name)
             else:
-                yield formatted_outputs(question, shared.model_name)
+                if not (shared.args.chat or shared.args.cai_chat):
+                    yield formatted_outputs(question, shared.model_name)
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
@@ -197,7 +198,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             def generate_with_streaming(**kwargs):
                 return Iteratorize(generate_with_callback, kwargs, callback=None)

-            yield formatted_outputs(original_question, shared.model_name)
+            if not (shared.args.chat or shared.args.cai_chat):
+                yield formatted_outputs(original_question, shared.model_name)
             with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
                     if shared.soft_prompt:
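
For context, a minimal sketch (with hypothetical helper names, not the actual webui code) of why the guard matters: the UI displays the most recent value the generator yields, so in chat mode the old unconditional yield of the bare prompt overwrote the "*Is typing*" placeholder before the first streamed token arrived. Guarding that yield with shared.args.chat or shared.args.cai_chat keeps the placeholder on screen until real output streams in.

    # Minimal sketch; stream_tokens and ui_render are hypothetical
    # stand-ins for the model streamer and the front-end update loop.

    def stream_tokens(prompt):
        # Pretend the model streams a reply a few characters at a time.
        for partial in ["Hel", "Hello", "Hello there!"]:
            yield partial

    def generate_reply(question, chat_mode):
        # The fix: only echo the raw prompt when NOT in chat mode, so the
        # "*Is typing*" placeholder survives until tokens arrive.
        if not chat_mode:
            yield question
        for partial in stream_tokens(question):
            yield partial

    def ui_render(generator):
        # The UI shows whatever was yielded last; each yield replaces the
        # previous display state (including the typing indicator).
        for state in generator:
            print(f"display -> {state!r}")

    ui_render(generate_reply("Hi", chat_mode=True))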