Sfoglia il codice sorgente

fix error from prepare call running twice in a row

Alex "mcmonkey" Goodwin 2 anni fa
parent
commit
5c49a0dcd0
1 ha cambiato i file con 2 aggiunte e 1 eliminazioni
  1. 2 1
      modules/training.py

+ 2 - 1
modules/training.py

@@ -90,7 +90,8 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
         evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
         evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
     # Start prepping the model itself
-    model = prepare_model_for_int8_training(model)
+    if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
+        model = prepare_model_for_int8_training(model)
     config = LoraConfig(
         r=loraRank,
         lora_alpha=loraAlpha,