|
|
@@ -1,3 +1,4 @@
|
|
|
+import random
|
|
|
import re
|
|
|
import time
|
|
|
import traceback
|
|
|
@@ -97,10 +98,11 @@ def formatted_outputs(reply, model_name):
|
|
|
|
|
|
|
|
|
def set_manual_seed(seed):
|
|
|
- if seed != -1:
|
|
|
- torch.manual_seed(seed)
|
|
|
- if torch.cuda.is_available():
|
|
|
- torch.cuda.manual_seed_all(seed)
|
|
|
+ if seed == -1:
|
|
|
+ seed = random.randint(1, 2**31)
|
|
|
+ torch.manual_seed(seed)
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
|
def stop_everything_event():
|