GPTQ_loader.py

import inspect
import re
import sys
from pathlib import Path

import accelerate
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

import modules.shared as shared

# GPTQ-for-LLaMa is expected to be cloned into repositories/GPTQ-for-LLaMa,
# so put it on the import path before importing from it.
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama_inference_offload
from modelutils import find_layers
from quant import make_quant


def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):

    def noop(*args, **kwargs):
        pass

    config = AutoConfig.from_pretrained(model)

    # Skip weight initialization: the checkpoint loaded below overwrites all
    # weights anyway, so initializing them here would only waste time.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = AutoModelForCausalLM.from_config(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()

    # Collect the linear layers to be replaced with quantized ones,
    # leaving out the excluded layers (lm_head by default).
    layers = find_layers(model)
    for name in exclude_layers:
        if name in layers:
            del layers[name]

    # Different GPTQ-for-LLaMa revisions expose different make_quant() signatures,
    # so only pass the optional arguments that this copy actually accepts.
    gptq_args = inspect.getfullargspec(make_quant).args

    make_quant_kwargs = {
        'module': model,
        'names': layers,
        'bits': wbits,
    }
    if 'groupsize' in gptq_args:
        make_quant_kwargs['groupsize'] = groupsize
    if 'faster' in gptq_args:
        make_quant_kwargs['faster'] = faster_kernel
    if 'kernel_switch_threshold' in gptq_args:
        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold

    make_quant(**make_quant_kwargs)
    del layers
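
    # As an illustration (the values are hypothetical), with a GPTQ-for-LLaMa copy
    # whose make_quant() accepts every optional argument above, the call expands to
    # something like:
    #   make_quant(module=model, names=layers, bits=4, groupsize=128,
    #              faster=False, kernel_switch_threshold=128)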

    # Load the quantized weights, either from a .safetensors file or from a
    # regular PyTorch checkpoint.
    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint), strict=False)
    else:
        model.load_state_dict(torch.load(checkpoint), strict=False)

    model.seqlen = 2048
    print('Done.')
    return model
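
# Example of calling _load_quant() directly (illustrative only; the paths and
# values below are hypothetical):
#   model = _load_quant('models/llama-7b', 'models/llama-7b-4bit.safetensors',
#                       wbits=4, groupsize=128)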


def load_quantized(model_name):

    # Work out the model family, either from --model_type or from the model's name.
    if not shared.args.model_type:
        # Try to determine the model type from the model name
        name = model_name.lower()
        if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
            model_type = 'llama'
        elif any((k in name for k in ['opt-', 'galactica'])):
            model_type = 'opt'
        elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
            model_type = 'gptj'
        else:
            print("Can't determine the model type from the model name. "
                  "Please specify it manually using the --model_type argument.")
            exit()
    else:
        model_type = shared.args.model_type.lower()

    # Pick the loader: qwopqwop200's offloading loader for llama when --pre_layer
    # is set, otherwise the generic _load_quant() defined above.
    if shared.args.pre_layer and model_type == 'llama':
        load_quant = llama_inference_offload.load_quant
    elif model_type in ('llama', 'opt', 'gptj'):
        if shared.args.pre_layer:
            print("Warning: ignoring --pre_layer because it only works for the llama model type.")
        load_quant = _load_quant
    else:
        print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported.")
        exit()

    # Now we are going to try to locate the quantized model file.
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    found_pts = list(path_to_model.glob("*.pt"))
    found_safetensors = list(path_to_model.glob("*.safetensors"))
    pt_path = None

    if len(found_pts) > 0:
        pt_path = found_pts[-1]
    elif len(found_safetensors) > 0:
        pt_path = found_safetensors[-1]
    else:
        # Nothing inside the model folder: fall back to the '<model>-<wbits>bit'
        # naming convention.
        if path_to_model.name.lower().startswith('llama-7b'):
            pt_model = f'llama-7b-{shared.args.wbits}bit'
        elif path_to_model.name.lower().startswith('llama-13b'):
            pt_model = f'llama-13b-{shared.args.wbits}bit'
        elif path_to_model.name.lower().startswith('llama-30b'):
            pt_model = f'llama-30b-{shared.args.wbits}bit'
        elif path_to_model.name.lower().startswith('llama-65b'):
            pt_model = f'llama-65b-{shared.args.wbits}bit'
        else:
            pt_model = f'{model_name}-{shared.args.wbits}bit'

        # Try to find the .safetensors or .pt file both in the model dir and in its subfolder
        for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
            if path.exists():
                pt_path = path
                break
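
        # For example (hypothetical values): with --wbits 4, model_name 'llama-13b',
        # and shared.args.model_dir set to 'models', the loop above checks in order:
        #   models/llama-13b-4bit.safetensors, models/llama-13b/llama-13b-4bit.safetensors,
        #   models/llama-13b-4bit.pt, models/llama-13b/llama-13b-4bit.pt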

    if not pt_path:
        print("Could not find the quantized model in .pt or .safetensors format, exiting...")
        exit()
    else:
        print(f"Found the following quantized model: {pt_path}")

    # qwopqwop200's offload
    if model_type == 'llama' and shared.args.pre_layer:
        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
    else:
        threshold = False if model_type == 'gptj' else 128
        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)

        # accelerate offload (doesn't work properly)
        if shared.args.gpu_memory:
            # Build the max_memory dict that accelerate expects: one entry per GPU
            # plus a 'cpu' entry. Bare numbers are interpreted as GiB.
            memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
            max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
            max_memory = {}
            for i in range(len(memory_map)):
                max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
            max_memory['cpu'] = max_cpu_memory
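
            # For example (hypothetical values): shared.args.gpu_memory == ['10', '5']
            # with shared.args.cpu_memory left unset produces
            #   max_memory == {0: '10GiB', 1: '5GiB', 'cpu': '99GiB'}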

            device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
            print("Using the following device map for the 4-bit model:", device_map)
            # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
            model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)

        # No offload
        elif not shared.args.cpu:
            model = model.to(torch.device('cuda:0'))

    return model
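
# Usage sketch (illustrative; assumes the web UI's argument parser has already
# populated shared.args with wbits, groupsize, model_type, model_dir, etc.):
#   model = load_quantized('llama-7b-4bit')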