瀏覽代碼

Handle training exception for unsupported models

oobabooga 2 年之前
父節點
當前提交
58349f44a0
共有 1 個文件被更改,包括 8 次插入1 次删除
  1. 8 1
      modules/training.py

+ 8 - 1
modules/training.py

@@ -2,6 +2,7 @@ import json
 import sys
 import threading
 import time
+import traceback
 from pathlib import Path
 
 import gradio as gr
@@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         bias="none",
         task_type="CAUSAL_LM"
     )
-    lora_model = get_peft_model(shared.model, config)
+
+    try:
+        lora_model = get_peft_model(shared.model, config)
+    except:
+        yield traceback.format_exc()
+        return
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,