@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop
-    torch.nn.init.uniform_ = noop
-    torch.nn.init.normal_ = noop
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
 
     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model,
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+    make_quant(**make_quant_kwargs)
 
     del layers
-
+
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict=False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict=False)
     model.seqlen = 2048
     print('Done.')
 
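
The three `torch.nn.init` assignments carried through in the second hunk exist to skip random weight initialization: together with `transformers.modeling_utils._init_weights = False`, they turn PyTorch's initializers into no-ops so that building the empty model skeleton does no wasted work before the quantized checkpoint overwrites every tensor. A minimal sketch of the same trick outside this codebase (the layer size is arbitrary):

```python
import torch
import torch.nn as nn

def noop(*args, **kwargs):
    pass

# Patch the initializers away before constructing the model skeleton:
# the parameter tensors are still allocated, but the random-fill passes
# (kaiming_uniform_ for weights, uniform_ for biases) never run.
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop

layer = nn.Linear(4096, 4096)  # built without any initialization work;
                               # its weights hold arbitrary memory contents
                               # until a checkpoint is loaded over them
```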
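The replacement of the positional `make_quant(...)` call is the core of the change: a feature-detection shim. Different versions of the quantization backend expose `make_quant` with different signatures, so the loader asks `inspect.getfullargspec` which parameter names the installed copy accepts and forwards only those keyword arguments. A self-contained sketch of the pattern, with hypothetical `make_quant_old`/`make_quant_new` stand-ins for the real backends:

```python
import inspect

# Hypothetical stand-ins for two incompatible versions of the same API.
def make_quant_old(module, names, bits):
    print(f'old signature: bits={bits}')

def make_quant_new(module, names, bits, groupsize=-1, faster=False, kernel_switch_threshold=128):
    print(f'new signature: bits={bits}, groupsize={groupsize}')

def call_make_quant(make_quant, module, names, bits,
                    groupsize=-1, faster=False, kernel_switch_threshold=128):
    # Ask the installed function which parameter names it accepts ...
    accepted = inspect.getfullargspec(make_quant).args
    kwargs = {'module': module, 'names': names, 'bits': bits}
    # ... and forward only the optional arguments it understands.
    for name, value in [('groupsize', groupsize),
                        ('faster', faster),
                        ('kernel_switch_threshold', kernel_switch_threshold)]:
        if name in accepted:
            kwargs[name] = value
    make_quant(**kwargs)

call_make_quant(make_quant_old, None, {}, 4)                 # old signature: bits=4
call_make_quant(make_quant_new, None, {}, 4, groupsize=128)  # new signature: bits=4, groupsize=128
```

Calling by keyword rather than position also means the shim keeps working if a backend reorders its optional parameters.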
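Finally, passing `strict=False` to `load_state_dict` keeps loading from failing on key mismatches, presumably because the quantized module tree produced by `make_quant` does not match the checkpoint key-for-key across backend versions: unexpected entries are skipped, parameters absent from the checkpoint keep their current values, and both lists are returned instead of a `RuntimeError` being raised. A small illustration of that behavior (the module and keys here are arbitrary examples):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# A checkpoint that is missing 'bias' and carries an extra key.
state = {'weight': torch.zeros(2, 4), 'extra.scales': torch.ones(2)}

# strict=False tolerates the mismatch and reports it instead of raising.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra.scales']
```

The trade-off is that genuinely wrong checkpoints also load silently, so callers who care should check the returned key lists.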