Explorar o código

Proper way to free the cuda cache

oobabooga hai 2 anos
pai
achega
fa58fd5559
Modificouse 1 ficheiro con 4 adicións e 1 borrado
  1. 4 1
      modules/text_generation.py

+ 4 - 1
modules/text_generation.py

@@ -1,3 +1,4 @@
+import gc
 import re
 import time
 
@@ -73,7 +74,9 @@ def formatted_outputs(reply, model_name):
         return reply
 
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    torch.cuda.empty_cache()
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):