@@ -1,10 +1,17 @@
-import sys, torch, json, threading, time
+import json
+import sys
+import threading
+import time
 from pathlib import Path
+
 import gradio as gr
-from datasets import load_dataset
+import torch
 import transformers
-from modules import ui, shared
-from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
+from datasets import load_dataset
+from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
+                  prepare_model_for_int8_training)
+
+from modules import shared, ui

 WANT_INTERRUPT = False
 CURRENT_STEPS = 0
@@ -44,7 +51,7 @@ def create_train_interface():
         with gr.Row():
             startButton = gr.Button("Start LoRA Training")
             stopButton = gr.Button("Interrupt")
-        output = gr.Markdown(value="(...)")
+        output = gr.Markdown(value="Ready")
         startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
         stopButton.click(doInterrupt, [], [], cancels=[], queue=False)

@@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     ).__get__(loraModel, type(loraModel))
     if torch.__version__ >= "2" and sys.platform != "win32":
         loraModel = torch.compile(loraModel)
+
     # == Main run and monitor loop ==
     # TODO: save/load checkpoints to resume from?
     print("Starting training...")
     yield "Starting..."
+
     def threadedRun():
         trainer.train()
+
     thread = threading.Thread(target=threadedRun)
     thread.start()
     lastStep = 0
     startTime = time.perf_counter()
+
     while thread.is_alive():
         time.sleep(0.5)
         if WANT_INTERRUPT:
@@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
                     timerInfo = f"`{1.0/its:.2f}` s/it"
                 totalTimeEstimate = (1.0/its) * (MAX_STEPS)
             yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
+
     print("Training complete, saving...")
     loraModel.save_pretrained(loraName)
+
     if WANT_INTERRUPT:
         print("Training interrupted.")
         yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"