Explorar el Código

Fix LoRA in CPU mode

oobabooga hace 2 años
padre
commit
29bd41d453
Se han modificado 1 ficheros con 2 adiciones y 2 borrados
  1. 2 2
      modules/LoRA.py

+ 2 - 2
modules/LoRA.py

@@ -18,10 +18,10 @@ def add_lora_to_model(lora_name):
         params = {}
         if shared.args.load_in_8bit:
             params['device_map'] = {'': 0}
-        else:
+        elif not shared.args.cpu:
             params['device_map'] = 'auto' 
             params['dtype'] = shared.model.dtype
             
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
-        if not shared.args.load_in_8bit:
+        if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()