Ver código fonte

Broaden GPTQ-for-LLaMA branch support (#820)

EyeDeck 2 anos atrás
pai
commit
39f3fec913
1 arquivos alterados com 23 adições e 7 exclusões
  1. 23 7
      modules/GPTQ_loader.py

+ 23 - 7
modules/GPTQ_loader.py

@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop 
-    torch.nn.init.uniform_ = noop 
-    torch.nn.init.normal_ = noop 
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
 
     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+    
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model, 
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+    
+    make_quant(**make_quant_kwargs)
 
     del layers
-    
+
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict = False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict = False)
     model.seqlen = 2048
     print('Done.')