Explorar el Código

fix error from prepare call running twice in a row

Alex "mcmonkey" Goodwin hace 2 años
padre
commit
5c49a0dcd0
Se han modificado 1 ficheros con 2 adiciones y 1 borrados
  1. 2 1
      modules/training.py

+ 2 - 1
modules/training.py

@@ -90,7 +90,8 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
         evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
         evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
     # Start prepping the model itself
-    model = prepare_model_for_int8_training(model)
+    if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
+        model = prepare_model_for_int8_training(model)
     config = LoraConfig(
         r=loraRank,
         lora_alpha=loraAlpha,