
Properly count tokens/s for llama.cpp in chat mode

oobabooga, 2 years ago
commit 0aee7341d8
1 file changed with 8 additions and 4 deletions

modules/text_generation.py (+8 -4)

@@ -120,6 +120,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+                output = original_question+reply
                 if not (shared.args.chat or shared.args.cai_chat):
                     reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)
@@ -130,6 +131,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+                    output = original_question+reply
                     if not (shared.args.chat or shared.args.cai_chat):
                         reply = original_question + apply_extensions(reply, "output")
                     yield formatted_outputs(reply, shared.model_name)
@@ -138,9 +140,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             traceback.print_exc()
         finally:
             t1 = time.time()
-            output = encode(reply)[0]
-            input_ids = encode(question)
-            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+            original_tokens = len(encode(original_question)[0])
+            new_tokens = len(encode(output)[0]) - original_tokens
+            print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
             return
 
     input_ids = encode(question, max_new_tokens)
@@ -272,5 +274,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         traceback.print_exc()
     finally:
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens, context {len(original_input_ids[0])})")
+        original_tokens = len(original_input_ids[0])
+        new_tokens = len(output)-original_tokens
+        print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
         return
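
For context, the diff changes the tokens/s accounting so that it compares the full output (original question plus reply, stored in `output`) against `original_question`, and therefore reports only newly generated tokens along with the context size. Below is a minimal, self-contained sketch of that accounting; the function name `report_generation_speed` and the whitespace `_dummy_encode` tokenizer are hypothetical stand-ins for the real `encode()` helper in modules/text_generation.py, which is assumed to return a batch of token sequences (so `encode(text)[0]` is the token list for `text`).

    import time

    def report_generation_speed(encode, original_question, output, t0, t1):
        """Sketch of the stats printed in the finally blocks above.

        `encode` is assumed to behave like the helper in
        modules/text_generation.py: it returns a batch, so encode(text)[0]
        is the token sequence for `text`.
        """
        original_tokens = len(encode(original_question)[0])    # prompt/context length
        new_tokens = len(encode(output)[0]) - original_tokens  # only the generated part
        elapsed = t1 - t0
        print(f"Output generated in {elapsed:.2f} seconds "
              f"({new_tokens / elapsed:.2f} tokens/s, {new_tokens} tokens, "
              f"context {original_tokens})")

    def _dummy_encode(text):
        # Hypothetical whitespace "tokenizer" standing in for the real encode().
        return [text.split()]

    if __name__ == "__main__":
        t0 = time.time()
        question = "Hello, how are you?"
        output = question + " I'm doing fine, thanks for asking."
        t1 = t0 + 1.0  # pretend generation took one second
        report_generation_speed(_dummy_encode, question, output, t0, t1)

Because `output` is captured before `apply_extensions()` rewrites `reply` for non-chat output, the count is based on the concatenated original question and raw reply rather than on the streamed, post-processed string.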