Forráskód Böngészése

Clear cache while switching LoRAs

oobabooga 2 éve
szülő
commit
bf22d16ebc
3 módosított fájl, 13 hozzáadás és 24 törlés
  1. 9 6
      modules/LoRA.py
  2. 1 7
      modules/callbacks.py
  3. 3 11
      server.py

+ 9 - 6
modules/LoRA.py

@@ -2,19 +2,22 @@ from pathlib import Path
 
 import modules.shared as shared
 from modules.models import load_model
+from modules.text_generation import clear_torch_cache
 
 
+def reload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+
 def add_lora_to_model(lora_name):
 
     from peft import PeftModel
 
-    # Is there a more efficient way of returning to the base model?
-    if lora_name == "None":
-        print("Reloading the model to remove the LoRA...")
-        shared.model, shared.tokenizer = load_model(shared.model_name)
-    else:
+    reload_model()
+
+    if lora_name != "None":
         print(f"Adding the LoRA {lora_name} to the model...")
-        
         params = {}
         if not shared.args.cpu:
             params['dtype'] = shared.model.dtype

+ 1 - 7
modules/callbacks.py

@@ -1,11 +1,10 @@
-import gc
 from queue import Queue
 from threading import Thread
 
 import torch
 import transformers
 
-import modules.shared as shared
+from modules.text_generation import clear_torch_cache
 
 
 # Copied from https://github.com/PygmalionAI/gradio-ui/
@@ -90,8 +89,3 @@ class Iteratorize:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
         clear_torch_cache()
-
-def clear_torch_cache():
-    gc.collect()
-    if not shared.args.cpu:
-        torch.cuda.empty_cache()

+ 3 - 11
server.py

@@ -1,4 +1,3 @@
-import gc
 import io
 import json
 import re
@@ -8,7 +7,6 @@ import zipfile
 from pathlib import Path
 
 import gradio as gr
-import torch
 
 import modules.chat as chat
 import modules.extensions as extensions_module
@@ -17,7 +15,7 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import generate_reply
+from modules.text_generation import clear_torch_cache, generate_reply
 
 # Loading custom settings
 settings_file = None
@@ -56,21 +54,15 @@ def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model = shared.tokenizer = None
-        if not shared.args.cpu:
-            gc.collect()
-            torch.cuda.empty_cache()
+        clear_torch_cache()
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
     return selected_model
 
 def load_lora_wrapper(selected_lora):
     shared.lora_name = selected_lora
-    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
     add_lora_to_model(selected_lora)
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
 
     return selected_lora, default_text