瀏覽代碼

Fix broken callbacks.py

oobabooga 2 年之前
父節點
當前提交
d1327f99f9
共有 1 個文件被更改,包括 5 次插入2 次删除
  1. 5 2
      modules/callbacks.py

+ 5 - 2
modules/callbacks.py

@@ -4,8 +4,6 @@ from threading import Thread
 import torch
 import transformers
 
-from modules.text_generation import clear_torch_cache
-
 
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
@@ -89,3 +87,8 @@ class Iteratorize:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
         clear_torch_cache()
+
+def clear_torch_cache():
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()