소스 검색

Handle training exception for unsupported models

oobabooga 2 년 전
부모
커밋
58349f44a0
1개의 변경된 파일8개의 추가작업 그리고 1개의 파일을 삭제
  1. 8 1
      modules/training.py

+ 8 - 1
modules/training.py

@@ -2,6 +2,7 @@ import json
 import sys
 import threading
 import time
+import traceback
 from pathlib import Path
 
 import gradio as gr
@@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         bias="none",
         task_type="CAUSAL_LM"
     )
-    lora_model = get_peft_model(shared.model, config)
+
+    try:
+        lora_model = get_peft_model(shared.model, config)
+    except:
+        yield traceback.format_exc()
+        return
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,