
Make LoRAs work in 16-bit mode

oobabooga, 2 years ago
Commit eac27f4f55
1 changed file with 9 additions and 4 deletions

modules/LoRA.py (+9, -4)

@@ -13,10 +13,15 @@ def add_lora_to_model(lora_name):
         print("Reloading the model to remove the LoRA...")
         shared.model, shared.tokenizer = load_model(shared.model_name)
     else:
-        # Why doesn't this work in 16-bit mode?
         print(f"Adding the LoRA {lora_name} to the model...")
-
+        
         params = {}
-        params['device_map'] = {'': 0}
-        #params['dtype'] = shared.model.dtype
+        if shared.args.load_in_8bit:
+            params['device_map'] = {'': 0}
+        else:
+            params['device_map'] = 'auto' 
+            params['dtype'] = shared.model.dtype
+            
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
+        if not shared.args.load_in_8bit:
+            shared.model.half()
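Put together, the patched branch amounts to the sketch below. This is a minimal reconstruction of the hunk above, not the repo's actual module: the webui's shared globals are replaced by plain arguments (model, lora_name, and load_in_8bit stand in for shared.model, the LoRA name, and shared.args.load_in_8bit), and whether PeftModel.from_pretrained actually consumes the dtype kwarg depends on the PEFT version in use.

from pathlib import Path

from peft import PeftModel

def add_lora(model, lora_name, load_in_8bit=False):
    # Attach a LoRA adapter from loras/<lora_name> to an already-loaded
    # transformers model, following the logic of the patch above.
    params = {}
    if load_in_8bit:
        # 8-bit (bitsandbytes) weights cannot be re-dispatched by
        # accelerate, so pin the whole model to GPU 0.
        params['device_map'] = {'': 0}
    else:
        # 16-bit path added by this commit: let accelerate place the
        # weights and forward the base model's dtype.
        params['device_map'] = 'auto'
        params['dtype'] = model.dtype

    model = PeftModel.from_pretrained(model, Path(f"loras/{lora_name}"), **params)

    if not load_in_8bit:
        # Cast back to half precision so the freshly loaded adapter
        # weights match the 16-bit base weights.
        model.half()

    return model

The final model.half() call is presumably what the commit title refers to: the adapter weights can come in at full precision, and casting the combined PeftModel back to fp16 keeps it consistent with the half-precision base model.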