Browse Source

Small style changes

oobabooga cách đây 2 năm
mục cha
commit
2f0571bfa4
3 tập tin đã thay đổi với 20 bổ sung và 7 xóa
  1. 1 1
      css/main.css
  2. 18 5
      modules/training.py
  3. 1 1
      server.py

+ 1 - 1
css/main.css

@@ -41,7 +41,7 @@ ol li p, ul li p {
     display: inline-block;
 }
 
-#main, #parameters, #chat-settings, #interface-mode, #lora {
+#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
   border: 0;
 }
 

+ 18 - 5
modules/training.py

@@ -1,10 +1,17 @@
-import sys, torch, json, threading, time
+import json
+import sys
+import threading
+import time
 from pathlib import Path
+
 import gradio as gr
-from datasets import load_dataset
+import torch
 import transformers
-from modules import ui, shared
-from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
+from datasets import load_dataset
+from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
+                  prepare_model_for_int8_training)
+
+from modules import shared, ui
 
 WANT_INTERRUPT = False
 CURRENT_STEPS = 0
@@ -44,7 +51,7 @@ def create_train_interface():
         with gr.Row():
             startButton = gr.Button("Start LoRA Training")
             stopButton = gr.Button("Interrupt")
-        output = gr.Markdown(value="(...)")
+        output = gr.Markdown(value="Ready")
         startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
         stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
 
@@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     ).__get__(loraModel, type(loraModel))
     if torch.__version__ >= "2" and sys.platform != "win32":
         loraModel = torch.compile(loraModel)
+
     # == Main run and monitor loop ==
     # TODO: save/load checkpoints to resume from?
     print("Starting training...")
     yield "Starting..."
+
     def threadedRun():
         trainer.train()
+
     thread = threading.Thread(target=threadedRun)
     thread.start()
     lastStep = 0
     startTime = time.perf_counter()
+
     while thread.is_alive():
         time.sleep(0.5)
         if WANT_INTERRUPT:
@@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
                     timerInfo = f"`{1.0/its:.2f}` s/it"
                 totalTimeEstimate = (1.0/its) * (MAX_STEPS)
             yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
+
     print("Training complete, saving...")
     loraModel.save_pretrained(loraName)
+
     if WANT_INTERRUPT:
         print("Training interrupted.")
         yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"

+ 1 - 1
server.py

@@ -9,8 +9,8 @@ from pathlib import Path
 
 import gradio as gr
 
-from modules import chat, shared, ui, training
 import modules.extensions as extensions_module
+from modules import chat, shared, training, ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt