Parcourir la source

Proper way to free the cuda cache

oobabooga il y a 2 ans
Parent
commit
fa58fd5559
1 fichiers modifiés avec 4 ajouts et 1 suppressions
  1. 4 1
      modules/text_generation.py

+ 4 - 1
modules/text_generation.py

@@ -1,3 +1,4 @@
+import gc
 import re
 import time
 
@@ -73,7 +74,9 @@ def formatted_outputs(reply, model_name):
         return reply
 
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    torch.cuda.empty_cache()
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):