@@ -1,4 +1,3 @@
-import gc
 import io
 import json
 import re
@@ -8,7 +7,6 @@ import zipfile
 from pathlib import Path
 
 import gradio as gr
-import torch
 
 import modules.chat as chat
 import modules.extensions as extensions_module
@@ -17,7 +15,7 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import generate_reply
+from modules.text_generation import clear_torch_cache, generate_reply
 
 # Loading custom settings
 settings_file = None
@@ -56,21 +54,15 @@ def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model = shared.tokenizer = None
-        if not shared.args.cpu:
-            gc.collect()
-            torch.cuda.empty_cache()
+        clear_torch_cache()
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
     return selected_model
 
 def load_lora_wrapper(selected_lora):
     shared.lora_name = selected_lora
-    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
     add_lora_to_model(selected_lora)
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
 
     return selected_lora, default_text
 
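
Note: clear_torch_cache() is imported from modules.text_generation, but its body lives in
modules/text_generation.py and is not shown in this patch. Judging from the inline code it
replaces above, the helper is assumed to wrap the same two steps, a gc.collect() followed by
torch.cuda.empty_cache() when not running in CPU-only mode. A minimal sketch under that
assumption (not the repository's verbatim implementation):

    # Sketch only; the real helper is defined in modules/text_generation.py.
    import gc

    import torch

    import modules.shared as shared

    def clear_torch_cache():
        # Drop dangling Python references to the old model first,
        # then release cached CUDA memory unless running on CPU.
        gc.collect()
        if not shared.args.cpu:
            torch.cuda.empty_cache()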