Просмотр исходного кода

fix random seeds to actually randomize (#1004 from mcmonkey4eva/seed-fix)

oobabooga 2 лет назад
Родитель
Сommit
843f672227
1 измененных файлов с 11 добавлено и 7 удалено
  1. 11 7
      modules/text_generation.py

+ 11 - 7
modules/text_generation.py

@@ -1,3 +1,4 @@
+import random
 import re
 import time
 import traceback
@@ -97,10 +98,13 @@ def formatted_outputs(reply, model_name):
 
 
 def set_manual_seed(seed):
-    if seed != -1:
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(seed)
+    seed = int(seed)
+    if seed == -1:
+        seed = random.randint(1, 2**31)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    return seed
 
 
 def stop_everything_event():
@@ -109,7 +113,7 @@ def stop_everything_event():
 
 def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
-    set_manual_seed(generate_state['seed'])
+    seed = set_manual_seed(generate_state['seed'])
     shared.stop_everything = False
     generate_params = {}
     t0 = time.time()
@@ -151,7 +155,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
             t1 = time.time()
             original_tokens = len(encode(original_question)[0])
             new_tokens = len(encode(output)[0]) - original_tokens
-            print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
+            print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
             return
 
     input_ids = encode(question, generate_state['max_new_tokens'])
@@ -272,5 +276,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
         new_tokens = len(output) - original_tokens
-        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
+        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return