download-model.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import argparse
  7. import base64
  8. import datetime
  9. import hashlib
  10. import json
  11. import re
  12. import sys
  13. from pathlib import Path
  14. import requests
  15. import tqdm
  16. from tqdm.contrib.concurrent import thread_map
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('MODEL', type=str, default=None, nargs='?')
  19. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  20. parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
  21. parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
  22. parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
  23. parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
  24. parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
  25. args = parser.parse_args()
  26. def select_model_from_default_options():
  27. models = {
  28. "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
  29. "OPT 2.7B": ("facebook", "opt-2.7b", "main"),
  30. "OPT 1.3B": ("facebook", "opt-1.3b", "main"),
  31. "OPT 350M": ("facebook", "opt-350m", "main"),
  32. "GALACTICA 6.7B": ("facebook", "galactica-6.7b", "main"),
  33. "GALACTICA 1.3B": ("facebook", "galactica-1.3b", "main"),
  34. "GALACTICA 125M": ("facebook", "galactica-125m", "main"),
  35. "Pythia-6.9B-deduped": ("EleutherAI", "pythia-6.9b-deduped", "main"),
  36. "Pythia-2.8B-deduped": ("EleutherAI", "pythia-2.8b-deduped", "main"),
  37. "Pythia-1.4B-deduped": ("EleutherAI", "pythia-1.4b-deduped", "main"),
  38. "Pythia-410M-deduped": ("EleutherAI", "pythia-410m-deduped", "main"),
  39. }
  40. choices = {}
  41. print("Select the model that you want to download:\n")
  42. for i, name in enumerate(models):
  43. char = chr(ord('A') + i)
  44. choices[char] = name
  45. print(f"{char}) {name}")
  46. char = chr(ord('A') + len(models))
  47. print(f"{char}) None of the above")
  48. print()
  49. print("Input> ", end='')
  50. choice = input()[0].strip().upper()
  51. if choice == char:
  52. print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
  53. Examples:
  54. facebook/opt-1.3b
  55. EleutherAI/pythia-1.4b-deduped
  56. """)
  57. print("Input> ", end='')
  58. model = input()
  59. branch = "main"
  60. else:
  61. arr = models[choices[choice]]
  62. model = f"{arr[0]}/{arr[1]}"
  63. branch = arr[2]
  64. return model, branch
  65. def sanitize_model_and_branch_names(model, branch):
  66. if model[-1] == '/':
  67. model = model[:-1]
  68. if branch is None:
  69. branch = "main"
  70. else:
  71. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  72. if not pattern.match(branch):
  73. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  74. return model, branch
  75. def get_download_links_from_huggingface(model, branch, text_only=False):
  76. base = "https://huggingface.co"
  77. page = f"/api/models/{model}/tree/{branch}?cursor="
  78. cursor = b""
  79. links = []
  80. sha256 = []
  81. classifications = []
  82. has_pytorch = False
  83. has_pt = False
  84. has_ggml = False
  85. has_safetensors = False
  86. is_lora = False
  87. while True:
  88. content = requests.get(f"{base}{page}{cursor.decode()}").content
  89. dict = json.loads(content)
  90. if len(dict) == 0:
  91. break
  92. for i in range(len(dict)):
  93. fname = dict[i]['path']
  94. if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
  95. is_lora = True
  96. is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
  97. is_safetensors = re.match(".*\.safetensors", fname)
  98. is_pt = re.match(".*\.pt", fname)
  99. is_ggml = re.match("ggml.*\.bin", fname)
  100. is_tokenizer = re.match("tokenizer.*\.model", fname)
  101. is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
  102. if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
  103. if 'lfs' in dict[i]:
  104. sha256.append([fname, dict[i]['lfs']['oid']])
  105. if is_text:
  106. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  107. classifications.append('text')
  108. continue
  109. if not text_only:
  110. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  111. if is_safetensors:
  112. has_safetensors = True
  113. classifications.append('safetensors')
  114. elif is_pytorch:
  115. has_pytorch = True
  116. classifications.append('pytorch')
  117. elif is_pt:
  118. has_pt = True
  119. classifications.append('pt')
  120. elif is_ggml:
  121. has_ggml = True
  122. classifications.append('ggml')
  123. cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
  124. cursor = base64.b64encode(cursor)
  125. cursor = cursor.replace(b'=', b'%3D')
  126. # If both pytorch and safetensors are available, download safetensors only
  127. if (has_pytorch or has_pt) and has_safetensors:
  128. for i in range(len(classifications) - 1, -1, -1):
  129. if classifications[i] in ['pytorch', 'pt']:
  130. links.pop(i)
  131. return links, sha256, is_lora
  132. def get_output_folder(model, branch, is_lora, base_folder=None):
  133. if base_folder is None:
  134. base_folder = 'models' if not is_lora else 'loras'
  135. output_folder = f"{'_'.join(model.split('/')[-2:])}"
  136. if branch != 'main':
  137. output_folder += f'_{branch}'
  138. output_folder = Path(base_folder) / output_folder
  139. return output_folder
  140. def get_single_file(url, output_folder, start_from_scratch=False):
  141. filename = Path(url.rsplit('/', 1)[1])
  142. output_path = output_folder / filename
  143. if output_path.exists() and not start_from_scratch:
  144. # Check if the file has already been downloaded completely
  145. r = requests.get(url, stream=True)
  146. total_size = int(r.headers.get('content-length', 0))
  147. if output_path.stat().st_size >= total_size:
  148. return
  149. # Otherwise, resume the download from where it left off
  150. headers = {'Range': f'bytes={output_path.stat().st_size}-'}
  151. mode = 'ab'
  152. else:
  153. headers = {}
  154. mode = 'wb'
  155. r = requests.get(url, stream=True, headers=headers)
  156. with open(output_path, mode) as f:
  157. total_size = int(r.headers.get('content-length', 0))
  158. block_size = 1024
  159. with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
  160. for data in r.iter_content(block_size):
  161. t.update(len(data))
  162. f.write(data)
  163. def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
  164. thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
  165. def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
  166. # Creating the folder and writing the metadata
  167. if not output_folder.exists():
  168. output_folder.mkdir()
  169. with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
  170. f.write(f'url: https://huggingface.co/{model}\n')
  171. f.write(f'branch: {branch}\n')
  172. f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
  173. sha256_str = ''
  174. for i in range(len(sha256)):
  175. sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
  176. if sha256_str != '':
  177. f.write(f'sha256sum:\n{sha256_str}')
  178. # Downloading the files
  179. print(f"Downloading the model to {output_folder}")
  180. start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
  181. def check_model_files(model, branch, links, sha256, output_folder):
  182. # Validate the checksums
  183. validated = True
  184. for i in range(len(sha256)):
  185. fpath = (output_folder / sha256[i][0])
  186. if not fpath.exists():
  187. print(f"The following file is missing: {fpath}")
  188. validated = False
  189. continue
  190. with open(output_folder / sha256[i][0], "rb") as f:
  191. bytes = f.read()
  192. file_hash = hashlib.sha256(bytes).hexdigest()
  193. if file_hash != sha256[i][1]:
  194. print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
  195. validated = False
  196. else:
  197. print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
  198. if validated:
  199. print('[+] Validated checksums of all model files!')
  200. else:
  201. print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
  202. if __name__ == '__main__':
  203. branch = args.branch
  204. model = args.MODEL
  205. if model is None:
  206. model, branch = select_model_from_default_options()
  207. # Cleaning up the model/branch names
  208. try:
  209. model, branch = sanitize_model_and_branch_names(model, branch)
  210. except ValueError as err_branch:
  211. print(f"Error: {err_branch}")
  212. sys.exit()
  213. # Getting the download links from Hugging Face
  214. links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
  215. # Getting the output folder
  216. output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
  217. if args.check:
  218. # Check previously downloaded files
  219. check_model_files(model, branch, links, sha256, output_folder)
  220. else:
  221. # Download files
  222. download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)