training.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import json
  2. import sys
  3. import threading
  4. import time
  5. from pathlib import Path
  6. import gradio as gr
  7. import torch
  8. import transformers
  9. from datasets import Dataset, load_dataset
  10. from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
  11. prepare_model_for_int8_training)
  12. from modules import shared, ui
  13. WANT_INTERRUPT = False
  14. CURRENT_STEPS = 0
  15. MAX_STEPS = 0
  16. CURRENT_GRADIENT_ACCUM = 1
  17. def get_dataset(path: str, ext: str):
  18. return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob(f'*.{ext}'))), key=str.lower)
  19. def create_train_interface():
  20. with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
  21. lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
  22. with gr.Row():
  23. # TODO: Implement multi-device support.
  24. micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
  25. batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
  26. with gr.Row():
  27. epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
  28. learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
  29. # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
  30. lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.')
  31. lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
  32. # TODO: Better explain what this does, in terms of real world effect especially.
  33. lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers.')
  34. cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
  35. with gr.Tab(label="Formatted Dataset"):
  36. with gr.Row():
  37. dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
  38. ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
  39. eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
  40. ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
  41. format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
  42. ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
  43. with gr.Tab(label="Raw Text File"):
  44. with gr.Row():
  45. raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
  46. ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
  47. overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=32, step=8, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above)')
  48. with gr.Row():
  49. start_button = gr.Button("Start LoRA Training")
  50. stop_button = gr.Button("Interrupt")
  51. output = gr.Markdown(value="Ready")
  52. start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
  53. stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
  54. def do_interrupt():
  55. global WANT_INTERRUPT
  56. WANT_INTERRUPT = True
  57. class Callbacks(transformers.TrainerCallback):
  58. def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
  59. global CURRENT_STEPS, MAX_STEPS
  60. CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM
  61. MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM
  62. if WANT_INTERRUPT:
  63. control.should_epoch_stop = True
  64. control.should_training_stop = True
  65. def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
  66. global CURRENT_STEPS
  67. CURRENT_STEPS += 1
  68. if WANT_INTERRUPT:
  69. control.should_epoch_stop = True
  70. control.should_training_stop = True
  71. def clean_path(base_path: str, path: str):
  72. """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
  73. # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
  74. # Or swap it to a strict whitelist of [a-zA-Z_0-9]
  75. path = path.replace('\\', '/').replace('..', '_')
  76. if base_path is None:
  77. return path
  78. return f'{Path(base_path).absolute()}/{path}'
  79. def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
  80. lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
  81. global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
  82. WANT_INTERRUPT = False
  83. CURRENT_STEPS = 0
  84. MAX_STEPS = 0
  85. # == Input validation / processing ==
  86. yield "Prepping..."
  87. lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
  88. actual_lr = float(learning_rate)
  89. if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
  90. yield f"Cannot input zeroes."
  91. return
  92. gradient_accumulation_steps = batch_size // micro_batch_size
  93. CURRENT_GRADIENT_ACCUM = gradient_accumulation_steps
  94. shared.tokenizer.pad_token = 0
  95. shared.tokenizer.padding_side = "left"
  96. def tokenize(prompt):
  97. result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
  98. return {
  99. "input_ids": result["input_ids"][:-1],
  100. "attention_mask": result["attention_mask"][:-1],
  101. }
  102. # == Prep the dataset, format, etc ==
  103. if raw_text_file is not None:
  104. print("Loading raw text file dataset...")
  105. with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
  106. raw_text = file.read()
  107. tokens = shared.tokenizer.encode(raw_text)
  108. del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
  109. tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
  110. for i in range(1, len(tokens)):
  111. tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
  112. text_chunks = [shared.tokenizer.decode(x) for x in tokens]
  113. del tokens
  114. data = Dataset.from_list([tokenize(x) for x in text_chunks])
  115. train_data = data.shuffle()
  116. eval_data = None
  117. del text_chunks
  118. else:
  119. with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
  120. format_data: dict[str, str] = json.load(formatFile)
  121. if dataset is None:
  122. yield "**Missing dataset choice input, cannot continue.**"
  123. return
  124. if format is None:
  125. yield "**Missing format choice input, cannot continue.**"
  126. return
  127. def generate_prompt(data_point: dict[str, str]):
  128. for options, data in format_data.items():
  129. if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):
  130. for key, val in data_point.items():
  131. data = data.replace(f'%{key}%', val)
  132. return data
  133. raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
  134. def generate_and_tokenize_prompt(data_point):
  135. prompt = generate_prompt(data_point)
  136. return tokenize(prompt)
  137. print("Loading JSON datasets...")
  138. data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
  139. train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
  140. if eval_dataset == 'None':
  141. eval_data = None
  142. else:
  143. eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
  144. eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
  145. # == Start prepping the model itself ==
  146. if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
  147. print("Getting model ready...")
  148. prepare_model_for_int8_training(shared.model)
  149. print("Prepping for training...")
  150. config = LoraConfig(
  151. r=lora_rank,
  152. lora_alpha=lora_alpha,
  153. # TODO: Should target_modules be configurable?
  154. target_modules=[ "q_proj", "v_proj" ],
  155. lora_dropout=lora_dropout,
  156. bias="none",
  157. task_type="CAUSAL_LM"
  158. )
  159. lora_model = get_peft_model(shared.model, config)
  160. trainer = transformers.Trainer(
  161. model=lora_model,
  162. train_dataset=train_data,
  163. eval_dataset=eval_data,
  164. args=transformers.TrainingArguments(
  165. per_device_train_batch_size=micro_batch_size,
  166. gradient_accumulation_steps=gradient_accumulation_steps,
  167. # TODO: Should more of these be configurable? Probably.
  168. warmup_steps=100,
  169. num_train_epochs=epochs,
  170. learning_rate=actual_lr,
  171. fp16=True,
  172. logging_steps=20,
  173. evaluation_strategy="steps" if eval_data is not None else "no",
  174. save_strategy="steps",
  175. eval_steps=200 if eval_data is not None else None,
  176. save_steps=200,
  177. output_dir=lora_name,
  178. save_total_limit=3,
  179. load_best_model_at_end=True if eval_data is not None else False,
  180. # TODO: Enable multi-device support
  181. ddp_find_unused_parameters=None
  182. ),
  183. data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
  184. callbacks=list([Callbacks()])
  185. )
  186. lora_model.config.use_cache = False
  187. old_state_dict = lora_model.state_dict
  188. lora_model.state_dict = (
  189. lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
  190. ).__get__(lora_model, type(lora_model))
  191. if torch.__version__ >= "2" and sys.platform != "win32":
  192. lora_model = torch.compile(lora_model)
  193. # == Main run and monitor loop ==
  194. # TODO: save/load checkpoints to resume from?
  195. print("Starting training...")
  196. yield "Starting..."
  197. def threadedRun():
  198. trainer.train()
  199. thread = threading.Thread(target=threadedRun)
  200. thread.start()
  201. lastStep = 0
  202. startTime = time.perf_counter()
  203. while thread.is_alive():
  204. time.sleep(0.5)
  205. if WANT_INTERRUPT:
  206. yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
  207. elif CURRENT_STEPS != lastStep:
  208. lastStep = CURRENT_STEPS
  209. timeElapsed = time.perf_counter() - startTime
  210. if timeElapsed <= 0:
  211. timerInfo = ""
  212. totalTimeEstimate = 999
  213. else:
  214. its = CURRENT_STEPS / timeElapsed
  215. if its > 1:
  216. timerInfo = f"`{its:.2f}` it/s"
  217. else:
  218. timerInfo = f"`{1.0/its:.2f}` s/it"
  219. totalTimeEstimate = (1.0/its) * (MAX_STEPS)
  220. yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
  221. print("Training complete, saving...")
  222. lora_model.save_pretrained(lora_name)
  223. if WANT_INTERRUPT:
  224. print("Training interrupted.")
  225. yield f"Interrupted. Incomplete LoRA saved to `{lora_name}`"
  226. else:
  227. print("Training complete!")
  228. yield f"Done! LoRA saved to `{lora_name}`"
  229. def split_chunks(arr, step):
  230. for i in range(0, len(arr), step):
  231. yield arr[i:i + step]