Jelajahi Sumber

Fix broken callbacks.py

oobabooga 2 tahun lalu
induk
melakukan
d1327f99f9
1 mengubah file dengan 5 tambahan dan 2 penghapusan
  1. 5 2
      modules/callbacks.py

+ 5 - 2
modules/callbacks.py

@@ -4,8 +4,6 @@ from threading import Thread
 import torch
 import transformers
 
-from modules.text_generation import clear_torch_cache
-
 
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
@@ -89,3 +87,8 @@ class Iteratorize:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
         clear_torch_cache()
+
+def clear_torch_cache():
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()