
Add support for latest cuda branch

oobabooga, 2 years ago
commit 8781c84287
1 changed file with 9 additions and 10 deletions

modules/GPTQ_loader.py  (+9, -10)

@@ -15,13 +15,13 @@ from modelutils import find_layers
 from quant import make_quant
 
 
-def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
+def _load_quant(model, checkpoint, wbits, groupsize=-1, exclude_layers=['lm_head']):
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop 
-    torch.nn.init.uniform_ = noop 
-    torch.nn.init.normal_ = noop 
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
 
     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +33,16 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+    make_quant(model, layers, wbits, groupsize)
 
     del layers
-    
+
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict = False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict = False)
     model.seqlen = 2048
     print('Done.')
 
@@ -110,8 +110,7 @@ def load_quantized(model_name):
     if shared.args.pre_layer:
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
     else:
-        threshold = False if model_type == 'gptj' else 128
-        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
+        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
 
         # accelerate offload (doesn't work properly)
         if shared.args.gpu_memory:
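
For context, a minimal sketch of how the simplified loader might be called after this change. It is not part of the commit: the model and checkpoint paths, and the `wbits`/`groupsize` values, are illustrative assumptions.

```python
# Hypothetical usage sketch of the new _load_quant signature, after the
# faster_kernel and kernel_switch_threshold arguments were removed.
from modules.GPTQ_loader import _load_quant

model = _load_quant(
    "models/llama-7b",                    # HF config directory (assumed path)
    "models/llama-7b-4bit.safetensors",   # quantized checkpoint (assumed path)
    wbits=4,                              # bit width used at quantization time
    groupsize=128,                        # group size used at quantization time
)

# load_state_dict is now called with strict=False inside _load_quant, so
# checkpoints produced by the latest GPTQ-for-LLaMa cuda branch (whose key set
# may differ slightly) can still be loaded without raising on missing keys.
```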