GPTQ_loader.py

import re
import sys
from pathlib import Path

import accelerate
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

import modules.shared as shared

sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama_inference_offload
from modelutils import find_layers
from quant import make_quant

def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
    config = AutoConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Skip the usual (slow) weight initialization; the weights are overwritten
    # by the quantized checkpoint below anyway.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()

    # Replace every linear layer except the excluded ones with its quantized
    # counterpart from GPTQ-for-LLaMa.
    layers = find_layers(model)
    for name in exclude_layers:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))

    model.seqlen = 2048
    print('Done.')
    return model

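# Illustrative direct call (a sketch, not part of this module): the paths,
# bit width and group size below are assumptions for a hypothetical 4-bit
# checkpoint, shown only to demonstrate how the helper above is meant to
# be invoked.
#
#   model = _load_quant('models/llama-7b-4bit',
#                       'models/llama-7b-4bit/llama-7b-4bit.safetensors',
#                       wbits=4, groupsize=128)
#   model = model.to(torch.device('cuda:0'))
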
def load_quantized(model_name):
    if not shared.args.model_type:
        # Try to determine model type from model name
        name = model_name.lower()
        if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
            model_type = 'llama'
        elif any((k in name for k in ['opt-', 'galactica'])):
            model_type = 'opt'
        elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
            model_type = 'gptj'
        else:
            print("Can't determine model type from model name. Please specify it manually using --model_type "
                  "argument")
            exit()
    else:
        model_type = shared.args.model_type.lower()

    # Pick the loader: qwopqwop200's offloading loader for llama with --pre_layer,
    # otherwise the generic _load_quant defined above.
    if shared.args.pre_layer and model_type == 'llama':
        load_quant = llama_inference_offload.load_quant
    elif model_type in ('llama', 'opt', 'gptj'):
        if shared.args.pre_layer:
            print("Warning: ignoring --pre_layer because it only works for llama model type.")
        load_quant = _load_quant
    else:
        print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
        exit()

    # Now we are going to try to locate the quantized model file.
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    found_pts = list(path_to_model.glob("*.pt"))
    found_safetensors = list(path_to_model.glob("*.safetensors"))
    pt_path = None

    if len(found_pts) == 1:
        pt_path = found_pts[0]
    elif len(found_safetensors) == 1:
        pt_path = found_safetensors[0]
    else:
        if path_to_model.name.lower().startswith('llama-7b'):
            pt_model = f'llama-7b-{shared.args.wbits}bit'
        elif path_to_model.name.lower().startswith('llama-13b'):
            pt_model = f'llama-13b-{shared.args.wbits}bit'
        elif path_to_model.name.lower().startswith('llama-30b'):
            pt_model = f'llama-30b-{shared.args.wbits}bit'
        elif path_to_model.name.lower().startswith('llama-65b'):
            pt_model = f'llama-65b-{shared.args.wbits}bit'
        else:
            pt_model = f'{model_name}-{shared.args.wbits}bit'

        # Try to find the .safetensors or .pt both in the model dir and in the subfolder
        for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
            if path.exists():
                print(f"Found {path}")
                pt_path = path
                break

    if not pt_path:
        print("Could not find the quantized model in .pt or .safetensors format, exiting...")
        exit()

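    # Illustrative layouts that the search above accepts (file names are
    # assumptions for a hypothetical llama-13b model quantized with --wbits 4):
    #
    #   <model_dir>/llama-13b-4bit/llama-13b-4bit.safetensors   # named file in the subfolder
    #   <model_dir>/llama-13b-4bit.safetensors                  # named file next to the subfolder
    #   <model_dir>/llama-13b-4bit/any-single-file.pt           # lone *.pt inside the subfolder
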
    # qwopqwop200's offload
    if model_type == 'llama' and shared.args.pre_layer:
        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
    else:
        threshold = False if model_type == 'gptj' else 128
        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)

        # accelerate offload (doesn't work properly)
        if shared.args.gpu_memory:
            memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
            max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
            max_memory = {}
            for i in range(len(memory_map)):
                max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
            max_memory['cpu'] = max_cpu_memory

            device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
            print("Using the following device map for the 4-bit model:", device_map)
            # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
            model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)

        # No offload
        elif not shared.args.cpu:
            model = model.to(torch.device('cuda:0'))

    return model
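
# -----------------------------------------------------------------------------
# Usage sketch (not part of the original module). In the web UI, load_quantized()
# is called after the command-line flags have been parsed into shared.args; the
# attribute names below mirror the shared.args accesses in this file, but the
# concrete values (model name, 4 bits, group size 128) are assumptions for a
# hypothetical checkpoint, and running this standalone still requires the
# repository layout that the imports at the top of the file expect.
if __name__ == '__main__':
    shared.args.model_dir = 'models'      # folder that contains the model subfolder
    shared.args.model_type = 'llama'      # 'llama', 'opt' or 'gptj'
    shared.args.wbits = 4                 # quantization bit width
    shared.args.groupsize = 128           # GPTQ group size (-1 to disable)
    shared.args.pre_layer = 0             # >0 enables llama_inference_offload
    shared.args.gpu_memory = None         # e.g. ['10'] to cap GPU 0 at 10GiB
    shared.args.cpu_memory = None
    shared.args.cpu = False

    model = load_quantized('llama-13b-4bit')
    print(model.seqlen)  # 2048, set by _load_quant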