oobabooga 2 лет назад
Родитель
Сommit
955cf431e8
1 измененных файлов с 5 добавлено и 5 удалено
  1. 5 5
      modules/text_generation.py

+ 5 - 5
modules/text_generation.py

@@ -68,10 +68,10 @@ def fix_galactica(s):
 
 def formatted_outputs(reply, model_name):
     if not (shared.args.chat or shared.args.cai_chat):
-        if shared.model_name.lower().startswith('galactica'):
+        if model_name.lower().startswith('galactica'):
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
-        elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
+        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
             reply = fix_gpt4chan(reply)
             return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
         else:
@@ -87,13 +87,13 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     if shared.is_RWKV:
         if shared.args.no_stream:
             reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
-            yield formatted_outputs(reply, None)
+            yield formatted_outputs(reply, shared.model_name)
         else:
             for i in range(max_new_tokens//8):
                 reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p)
-                yield formatted_outputs(reply, None)
+                yield formatted_outputs(reply, shared.model_name)
                 question = reply
-        return formatted_outputs(reply, None)
+        return formatted_outputs(reply, shared.model_name)
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):