# download-model.py
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import argparse
  7. import base64
  8. import datetime
  9. import hashlib
  10. import json
  11. import re
  12. import sys
  13. from pathlib import Path
  14. import requests
  15. import tqdm
  16. from tqdm.contrib.concurrent import thread_map
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('MODEL', type=str, default=None, nargs='?')
  19. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  20. parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
  21. parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
  22. parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
  23. parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
  24. parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
  25. args = parser.parse_args()
  26. def get_file(url, output_folder):
  27. filename = Path(url.rsplit('/', 1)[1])
  28. output_path = output_folder / filename
  29. if output_path.exists() and not args.clean:
  30. # Check if the file has already been downloaded completely
  31. r = requests.get(url, stream=True)
  32. total_size = int(r.headers.get('content-length', 0))
  33. if output_path.stat().st_size >= total_size:
  34. return
  35. # Otherwise, resume the download from where it left off
  36. headers = {'Range': f'bytes={output_path.stat().st_size}-'}
  37. mode = 'ab'
  38. else:
  39. headers = {}
  40. mode = 'wb'
  41. r = requests.get(url, stream=True, headers=headers)
  42. with open(output_path, mode) as f:
  43. total_size = int(r.headers.get('content-length', 0))
  44. block_size = 1024
  45. with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
  46. for data in r.iter_content(block_size):
  47. t.update(len(data))
  48. f.write(data)
  49. def sanitize_branch_name(branch_name):
  50. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  51. if pattern.match(branch_name):
  52. return branch_name
  53. else:
  54. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  55. def select_model_from_default_options():
  56. models = {
  57. "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
  58. "OPT 2.7B": ("facebook", "opt-2.7b", "main"),
  59. "OPT 1.3B": ("facebook", "opt-1.3b", "main"),
  60. "OPT 350M": ("facebook", "opt-350m", "main"),
  61. "GALACTICA 6.7B": ("facebook", "galactica-6.7b", "main"),
  62. "GALACTICA 1.3B": ("facebook", "galactica-1.3b", "main"),
  63. "GALACTICA 125M": ("facebook", "galactica-125m", "main"),
  64. "Pythia-6.9B-deduped": ("EleutherAI", "pythia-6.9b-deduped", "main"),
  65. "Pythia-2.8B-deduped": ("EleutherAI", "pythia-2.8b-deduped", "main"),
  66. "Pythia-1.4B-deduped": ("EleutherAI", "pythia-1.4b-deduped", "main"),
  67. "Pythia-410M-deduped": ("EleutherAI", "pythia-410m-deduped", "main"),
  68. }
  69. choices = {}
  70. print("Select the model that you want to download:\n")
  71. for i, name in enumerate(models):
  72. char = chr(ord('A') + i)
  73. choices[char] = name
  74. print(f"{char}) {name}")
  75. char = chr(ord('A') + len(models))
  76. print(f"{char}) None of the above")
  77. print()
  78. print("Input> ", end='')
  79. choice = input()[0].strip().upper()
  80. if choice == char:
  81. print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
  82. Examples:
  83. facebook/opt-1.3b
  84. EleutherAI/pythia-1.4b-deduped
  85. """)
  86. print("Input> ", end='')
  87. model = input()
  88. branch = "main"
  89. else:
  90. arr = models[choices[choice]]
  91. model = f"{arr[0]}/{arr[1]}"
  92. branch = arr[2]
  93. return model, branch
  94. def get_download_links_from_huggingface(model, branch):
  95. base = "https://huggingface.co"
  96. page = f"/api/models/{model}/tree/{branch}?cursor="
  97. cursor = b""
  98. links = []
  99. sha256 = []
  100. classifications = []
  101. has_pytorch = False
  102. has_pt = False
  103. has_ggml = False
  104. has_safetensors = False
  105. is_lora = False
  106. while True:
  107. content = requests.get(f"{base}{page}{cursor.decode()}").content
  108. dict = json.loads(content)
  109. if len(dict) == 0:
  110. break
  111. for i in range(len(dict)):
  112. fname = dict[i]['path']
  113. if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
  114. is_lora = True
  115. is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
  116. is_safetensors = re.match(".*\.safetensors", fname)
  117. is_pt = re.match(".*\.pt", fname)
  118. is_ggml = re.match("ggml.*\.bin", fname)
  119. is_tokenizer = re.match("tokenizer.*\.model", fname)
  120. is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
  121. if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
  122. if 'lfs' in dict[i]:
  123. sha256.append([fname, dict[i]['lfs']['oid']])
  124. if is_text:
  125. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  126. classifications.append('text')
  127. continue
  128. if not args.text_only:
  129. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  130. if is_safetensors:
  131. has_safetensors = True
  132. classifications.append('safetensors')
  133. elif is_pytorch:
  134. has_pytorch = True
  135. classifications.append('pytorch')
  136. elif is_pt:
  137. has_pt = True
  138. classifications.append('pt')
  139. elif is_ggml:
  140. has_ggml = True
  141. classifications.append('ggml')
  142. cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
  143. cursor = base64.b64encode(cursor)
  144. cursor = cursor.replace(b'=', b'%3D')
  145. # If both pytorch and safetensors are available, download safetensors only
  146. if (has_pytorch or has_pt) and has_safetensors:
  147. for i in range(len(classifications) - 1, -1, -1):
  148. if classifications[i] in ['pytorch', 'pt']:
  149. links.pop(i)
  150. return links, sha256, is_lora
  151. def download_files(file_list, output_folder, num_threads=8):
  152. thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
  153. if __name__ == '__main__':
  154. model = args.MODEL
  155. branch = args.branch
  156. if model is None:
  157. model, branch = select_model_from_default_options()
  158. else:
  159. if model[-1] == '/':
  160. model = model[:-1]
  161. branch = args.branch
  162. if branch is None:
  163. branch = "main"
  164. else:
  165. try:
  166. branch = sanitize_branch_name(branch)
  167. except ValueError as err_branch:
  168. print(f"Error: {err_branch}")
  169. sys.exit()
  170. links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
  171. if args.output is not None:
  172. base_folder = args.output
  173. else:
  174. base_folder = 'models' if not is_lora else 'loras'
  175. output_folder = f"{'_'.join(model.split('/')[-2:])}"
  176. if branch != 'main':
  177. output_folder += f'_{branch}'
  178. output_folder = Path(base_folder) / output_folder
  179. if args.check:
  180. # Validate the checksums
  181. validated = True
  182. for i in range(len(sha256)):
  183. fpath = (output_folder / sha256[i][0])
  184. if not fpath.exists():
  185. print(f"The following file is missing: {fpath}")
  186. validated = False
  187. continue
  188. with open(output_folder / sha256[i][0], "rb") as f:
  189. bytes = f.read()
  190. file_hash = hashlib.sha256(bytes).hexdigest()
  191. if file_hash != sha256[i][1]:
  192. print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
  193. validated = False
  194. else:
  195. print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
  196. if validated:
  197. print('[+] Validated checksums of all model files!')
  198. else:
  199. print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
  200. else:
  201. # Creating the folder and writing the metadata
  202. if not output_folder.exists():
  203. output_folder.mkdir()
  204. with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
  205. f.write(f'url: https://huggingface.co/{model}\n')
  206. f.write(f'branch: {branch}\n')
  207. f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
  208. sha256_str = ''
  209. for i in range(len(sha256)):
  210. sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
  211. if sha256_str != '':
  212. f.write(f'sha256sum:\n{sha256_str}')
  213. # Downloading the files
  214. print(f"Downloading the model to {output_folder}")
  215. download_files(links, output_folder, args.threads)