1 Commits

Author SHA1 Message Date
oobabooga
8781c84287 Add support for latest cuda branch 2023-04-05 00:09:53 -03:00

View File

@@ -15,13 +15,13 @@ from modelutils import find_layers
from quant import make_quant from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128): def _load_quant(model, checkpoint, wbits, groupsize=-1, exclude_layers=['lm_head']):
config = AutoConfig.from_pretrained(model) config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs): def noop(*args, **kwargs):
pass pass
torch.nn.init.kaiming_uniform_ = noop torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half) torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False transformers.modeling_utils._init_weights = False
@@ -33,16 +33,16 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
for name in exclude_layers: for name in exclude_layers:
if name in layers: if name in layers:
del layers[name] del layers[name]
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold) make_quant(model, layers, wbits, groupsize)
del layers del layers
print('Loading model ...') print('Loading model ...')
if checkpoint.endswith('.safetensors'): if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint)) model.load_state_dict(safe_load(checkpoint), strict = False)
else: else:
model.load_state_dict(torch.load(checkpoint)) model.load_state_dict(torch.load(checkpoint), strict = False)
model.seqlen = 2048 model.seqlen = 2048
print('Done.') print('Done.')
@@ -110,8 +110,7 @@ def load_quantized(model_name):
if shared.args.pre_layer: if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else: else:
threshold = False if model_type == 'gptj' else 128 model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly) # accelerate offload (doesn't work properly)
if shared.args.gpu_memory: if shared.args.gpu_memory: