Browse Source

Handle training exception for unsupported models

oobabooga 2 năm trước cách đây
mục cha
commit
58349f44a0
1 tập tin đã thay đổi với 8 bổ sung1 xóa
  1. 8 1
      modules/training.py

+ 8 - 1
modules/training.py

@@ -2,6 +2,7 @@ import json
 import sys
 import sys
 import threading
 import threading
 import time
 import time
+import traceback
 from pathlib import Path
 from pathlib import Path
 
 
 import gradio as gr
 import gradio as gr
@@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         bias="none",
         bias="none",
         task_type="CAUSAL_LM"
         task_type="CAUSAL_LM"
     )
     )
-    lora_model = get_peft_model(shared.model, config)
+
+    try:
+        lora_model = get_peft_model(shared.model, config)
+    except:
+        yield traceback.format_exc()
+        return
+
     trainer = transformers.Trainer(
     trainer = transformers.Trainer(
         model=lora_model,
         model=lora_model,
         train_dataset=train_data,
         train_dataset=train_data,