Ver Fonte

Don't split the layers in 8-bit mode by default

oobabooga há 2 anos atrás
pai
commit
ee164d1821
1 ficheiros alterados com 4 adições e 2 exclusões
  1. 4 2
      modules/models.py

+ 4 - 2
modules/models.py

@@ -105,8 +105,10 @@ def load_model(model_name):
             params["torch_dtype"] = torch.float32
             params["torch_dtype"] = torch.float32
         else:
         else:
             params["device_map"] = 'auto'
             params["device_map"] = 'auto'
-            if shared.args.load_in_8bit:
+            if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+            elif shared.args.load_in_8bit:
+                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
             elif shared.args.bf16:
             elif shared.args.bf16:
                 params["torch_dtype"] = torch.bfloat16
                 params["torch_dtype"] = torch.bfloat16
             else:
             else:
@@ -119,7 +121,7 @@ def load_model(model_name):
                     max_memory[i] = f'{memory_map[i]}GiB'
                     max_memory[i] = f'{memory_map[i]}GiB'
                 max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
                 max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
                 params['max_memory'] = max_memory
                 params['max_memory'] = max_memory
-            else:
+            elif shared.args.auto_devices:
                 total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
                 total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
                 suggestion = round((total_mem-1000) / 1000) * 1000
                 suggestion = round((total_mem-1000) / 1000) * 1000
                 if total_mem - suggestion < 800:
                 if total_mem - suggestion < 800: