Quellcode durchsuchen

Improve the text generation call a bit

oobabooga vor 2 Jahren
Ursprung
Commit
7c4d5ca8cc
2 geänderte Dateien mit 4 neuen und 5 gelöschten Zeilen
  1. 1 1
      modules/RWKV.py
  2. 3 4
      modules/text_generation.py

+ 1 - 1
modules/RWKV.py

@@ -42,4 +42,4 @@ class RWKVModel:
             token_stop = token_stop
         )
 
-        return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+        return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)

+ 3 - 4
modules/text_generation.py

@@ -86,15 +86,14 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
     if shared.is_RWKV:
         if shared.args.no_stream:
-            reply = question + shared.model.generate(question, token_count=max_new_tokens, temperature=temperature)
+            reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature)
             yield formatted_outputs(reply, None)
-            return formatted_outputs(reply, None)
         else:
             for i in range(max_new_tokens//8):
-                reply = question + shared.model.generate(question, token_count=8, temperature=temperature)
+                reply = shared.model.generate(question, token_count=8, temperature=temperature)
                 yield formatted_outputs(reply, None)
                 question = reply
-            return formatted_outputs(reply, None)
+        return formatted_outputs(reply, None)
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):