Prechádzať zdrojové kódy

Fix LoRA device map (attempt)

oobabooga 2 rokov pred
rodič
commit
9bf6ecf9e2
1 zmenil súbory, kde vykonal 7 pridanie a 4 odobranie
  1. 7 4
      modules/LoRA.py

+ 7 - 4
modules/LoRA.py

@@ -16,12 +16,15 @@ def add_lora_to_model(lora_name):
         print(f"Adding the LoRA {lora_name} to the model...")
         
         params = {}
-        if shared.args.load_in_8bit:
-            params['device_map'] = {'': 0}
-        elif not shared.args.cpu:
-            params['device_map'] = 'auto' 
+        if not shared.args.cpu:
             params['dtype'] = shared.model.dtype
+            if hasattr(shared.model, "hf_device_map"):
+                params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
+            elif shared.args.load_in_8bit:
+                params['device_map'] = {'': 0}
             
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
+            if not hasattr(shared.model, "hf_device_map"):
+                shared.model.cuda()