Sfoglia il codice sorgente

initial progress tracker in UI

Alex "mcmonkey" Goodwin 2 anni fa
parent
commit
8fc723fc95
1 ha cambiato i file con 40 aggiunte e 8 eliminazioni
  1. 40 8
      modules/training.py

+ 40 - 8
modules/training.py

@@ -1,4 +1,4 @@
-import sys, torch, json
+import sys, torch, json, threading, time
 from pathlib import Path
 import gradio as gr
 from datasets import load_dataset
@@ -6,6 +6,9 @@ import transformers
 from modules import ui, shared
 from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
 
+CURRENT_STEPS = 0
+MAX_STEPS = 0
+
 def get_json_dataset(path: str):
     def get_set():
         return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob('*.json'))), key=str.lower)
@@ -40,6 +43,12 @@ def create_train_interface():
         output = gr.Markdown(value="(...)")
         startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
 
+class Callbacks(transformers.TrainerCallback):
+    def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+        global CURRENT_STEPS, MAX_STEPS
+        CURRENT_STEPS = state.global_step
+        MAX_STEPS = state.max_steps
+
 def cleanPath(basePath: str, path: str):
     """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
     # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -50,8 +59,11 @@ def cleanPath(basePath: str, path: str):
     return f'{Path(basePath).absolute()}/{path}'
 
 def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
+    global CURRENT_STEPS, MAX_STEPS
+    CURRENT_STEPS = 0
+    MAX_STEPS = 0
     yield "Prepping..."
-    # Input validation / processing
+    # == Input validation / processing ==
     # TODO: --lora-dir PR once pulled will need to be applied here
     loraName = f"loras/{cleanPath(None, loraName)}"
     if dataset is None:
@@ -62,7 +74,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     actualLR = float(learningRate)
     shared.tokenizer.pad_token = 0
     shared.tokenizer.padding_side = "left"
-    # Prep the dataset, format, etc
+    # == Prep the dataset, format, etc ==
     with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
         formatData: dict[str, str] = json.load(formatFile)
     def tokenize(prompt):
@@ -89,7 +101,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     else:
         evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
         evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
-    # Start prepping the model itself
+    # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         print("Getting model ready...")
         prepare_model_for_int8_training(shared.model)
@@ -128,6 +140,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
             ddp_find_unused_parameters=None
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
+        callbacks=list([Callbacks()])
     )
     loraModel.config.use_cache = False
     old_state_dict = loraModel.state_dict
@@ -136,12 +149,31 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     ).__get__(loraModel, type(loraModel))
     if torch.__version__ >= "2" and sys.platform != "win32":
         loraModel = torch.compile(loraModel)
-    # Actually start and run and save at the end
+    # == Main run and monitor loop ==
     # TODO: save/load checkpoints to resume from?
     print("Starting training...")
-    yield "Running..."
-    trainer.train()
+    yield "Starting..."
+    def threadedRun():
+        trainer.train()
+    thread = threading.Thread(target=threadedRun)
+    thread.start()
+    lastStep = 0
+    startTime = time.perf_counter()
+    while thread.is_alive():
+        time.sleep(0.5)
+        if CURRENT_STEPS != lastStep:
+            lastStep = CURRENT_STEPS
+            timeElapsed = time.perf_counter() - startTime
+            if timeElapsed <= 0:
+                timerInfo = ""
+            else:
+                its = CURRENT_STEPS / timeElapsed
+                if its > 1:
+                    timerInfo = f"`{its:.2f}` it/s"
+                else:
+                    timerInfo = f"`{1.0/its:.2f}` s/it"
+            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
     print("Training complete, saving...")
     loraModel.save_pretrained(loraName)
     print("Training complete!")
-    yield f"Done! Lora saved to `{loraName}`"
+    yield f"Done! LoRA saved to `{loraName}`"