
interrupt button

Alex "mcmonkey" Goodwin 2 سال پیش
والد
کامیت
16ea4fc36d
1 changed file with 34 additions and 8 deletions

modules/training.py (+34 −8)

@@ -6,8 +6,10 @@ import transformers
 from modules import ui, shared
 from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
 
+WANT_INTERRUPT = False
 CURRENT_STEPS = 0
 MAX_STEPS = 0
+CURRENT_GRADIENT_ACCUM = 1
 
 def get_json_dataset(path: str):
     def get_set():
@@ -39,15 +41,31 @@ def create_train_interface():
             formatsFunction = get_json_dataset('training/formats')
             format = gr.Dropdown(choices=formatsFunction(), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
             ui.create_refresh_button(format, lambda : None, lambda : {'choices': formatsFunction()}, 'refresh-button')
-        startButton = gr.Button("Start LoRA Training")
+        with gr.Row():
+            startButton = gr.Button("Start LoRA Training")
+            stopButton = gr.Button("Interrupt")
         output = gr.Markdown(value="(...)")
-        startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
+        startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
+        stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
+
+def doInterrupt():
+    global WANT_INTERRUPT
+    WANT_INTERRUPT = True
 
 class Callbacks(transformers.TrainerCallback):
     def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
         global CURRENT_STEPS, MAX_STEPS
-        CURRENT_STEPS = state.global_step
-        MAX_STEPS = state.max_steps
+        CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM
+        MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM
+        if WANT_INTERRUPT:
+            control.should_epoch_stop = True
+            control.should_training_stop = True
+    def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+        global CURRENT_STEPS
+        CURRENT_STEPS += 1
+        if WANT_INTERRUPT:
+            control.should_epoch_stop = True
+            control.should_training_stop = True
 
 def cleanPath(basePath: str, path: str):
     """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
@@ -59,7 +77,8 @@ def cleanPath(basePath: str, path: str):
     return f'{Path(basePath).absolute()}/{path}'
 
 def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
-    global CURRENT_STEPS, MAX_STEPS
+    global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
+    WANT_INTERRUPT = False
     CURRENT_STEPS = 0
     MAX_STEPS = 0
     yield "Prepping..."
@@ -71,6 +90,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     if format is None:
         return "**Missing format choice input, cannot continue.**"
     gradientAccumulationSteps = batchSize // microBatchSize
+    CURRENT_GRADIENT_ACCUM = gradientAccumulationSteps
     actualLR = float(learningRate)
     shared.tokenizer.pad_token = 0
     shared.tokenizer.padding_side = "left"
@@ -161,7 +181,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     startTime = time.perf_counter()
     while thread.is_alive():
         time.sleep(0.5)
-        if CURRENT_STEPS != lastStep:
+        if WANT_INTERRUPT:
+            yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
+        elif CURRENT_STEPS != lastStep:
             lastStep = CURRENT_STEPS
             timeElapsed = time.perf_counter() - startTime
             if timeElapsed <= 0:
@@ -175,5 +197,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
             yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
     print("Training complete, saving...")
     loraModel.save_pretrained(loraName)
-    print("Training complete!")
-    yield f"Done! LoRA saved to `{loraName}`"
+    if WANT_INTERRUPT:
+        print("Training interrupted.")
+        yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"
+    else:
+        print("Training complete!")
+        yield f"Done! LoRA saved to `{loraName}`"