
Print the performance information more reliably

oobabooga 2 years ago
Parent
Commit
05e703b4a4
1 changed file with 8 additions and 3 deletions

+ 8 - 3
modules/text_generation.py

@@ -86,12 +86,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     if not shared.args.cpu:
         torch.cuda.empty_cache()
 
+    t0 = time.time()
+
+    # These models are not part of Hugging Face, so we handle them
+    # separately and terminate the function call earlier
     if shared.is_RWKV or shared.is_LLaMA:
         if shared.args.no_stream:
             reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
+            t1 = time.time()
+            print(f"Output generated in {(t1-t0):.2f} seconds.")
             yield formatted_outputs(reply, shared.model_name)
         else:
-            for i in range(max_new_tokens//8):
+            for i in tqdm(range(max_new_tokens//8+1)):
                 reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p)
                 yield formatted_outputs(reply, shared.model_name)
                 question = reply
@@ -160,7 +166,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
     # Generate the entire reply at once
     if shared.args.no_stream:
-        t0 = time.time()
         with torch.no_grad():
             output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
         if shared.soft_prompt:
@@ -169,10 +174,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         reply = decode(output)
         if not (shared.args.chat or shared.args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
-        yield formatted_outputs(reply, shared.model_name)
 
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
+        yield formatted_outputs(reply, shared.model_name)
 
     # Generate the reply 8 tokens at a time
     else:
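
A minimal, hypothetical sketch of the pattern this commit moves to: take timestamps around the model call and print the statistics before yielding the formatted output, so the timing line is emitted even if the caller stops consuming the generator after the first result. The generate_fn callable and the tokens-per-second report below are illustrative assumptions, not the project's actual code.

    import time

    def generate_with_stats(generate_fn, prompt, max_new_tokens=200):
        # generate_fn is a hypothetical callable returning a list of new token ids
        t0 = time.time()
        output = generate_fn(prompt, max_new_tokens)
        t1 = time.time()
        new_tokens = len(output)
        # Report timing before yielding, mirroring the reordering in this commit
        print(f"Output generated in {(t1 - t0):.2f} seconds "
              f"({new_tokens / (t1 - t0):.2f} tokens/s, {new_tokens} tokens)")
        yield output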