@@ -2,6 +2,7 @@ import json
 import sys
 import threading
 import time
+import traceback
 from pathlib import Path
 
 import gradio as gr
@@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         bias="none",
         task_type="CAUSAL_LM"
     )
-    lora_model = get_peft_model(shared.model, config)
+
+    try:
+        lora_model = get_peft_model(shared.model, config)
+    except:
+        yield traceback.format_exc()
+        return
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,
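
For reference, a minimal runnable sketch of the error-handling pattern this patch introduces: a generator-style Gradio handler that catches the failure, yields the formatted traceback so it shows up in the UI output box, and returns instead of crashing the server. The names here (do_work, the demo layout) are illustrative assumptions, not code from this repository.

import traceback

import gradio as gr


def do_work(text: str):
    try:
        # Stand-in for the fallible call (get_peft_model in the patch above).
        # An empty input triggers ZeroDivisionError, exercising the error path.
        result = 1 / len(text)
    except Exception:
        # Surface the full traceback in the UI instead of raising.
        yield traceback.format_exc()
        return
    yield f"ok: {result}"


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(do_work, inp, out)

demo.launch()

One design note: the patch itself uses a bare except:, which also catches KeyboardInterrupt and SystemExit; except Exception: (as in the sketch) is the narrower, more conventional form.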