download-model.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import argparse
  7. import base64
  8. import datetime
  9. import json
  10. import re
  11. import sys
  12. from pathlib import Path
  13. import requests
  14. import tqdm
  15. from tqdm.contrib.concurrent import thread_map
  16. import hashlib
# Command-line interface. MODEL is positional and optional; when omitted,
# an interactive menu of default models is shown instead.
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
# Parsed at import time; the functions below read this module-level `args`.
args = parser.parse_args()
  25. def get_file(url, output_folder):
  26. filename = Path(url.rsplit('/', 1)[1])
  27. output_path = output_folder / filename
  28. if output_path.exists() and not args.clean:
  29. # Check if the file has already been downloaded completely
  30. r = requests.head(url)
  31. total_size = int(r.headers.get('content-length', 0))
  32. if output_path.stat().st_size >= total_size:
  33. return
  34. # Otherwise, resume the download from where it left off
  35. headers = {'Range': f'bytes={output_path.stat().st_size}-'}
  36. mode = 'ab'
  37. else:
  38. headers = {}
  39. mode = 'wb'
  40. r = requests.get(url, stream=True, headers=headers)
  41. with open(output_path, mode) as f:
  42. total_size = int(r.headers.get('content-length', 0))
  43. block_size = 1024
  44. with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
  45. for data in r.iter_content(block_size):
  46. t.update(len(data))
  47. f.write(data)
  48. def sanitize_branch_name(branch_name):
  49. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  50. if pattern.match(branch_name):
  51. return branch_name
  52. else:
  53. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  54. def select_model_from_default_options():
  55. models = {
  56. "Pygmalion 6B original": ("PygmalionAI", "pygmalion-6b", "b8344bb4eb76a437797ad3b19420a13922aaabe1"),
  57. "Pygmalion 6B main": ("PygmalionAI", "pygmalion-6b", "main"),
  58. "Pygmalion 6B dev": ("PygmalionAI", "pygmalion-6b", "dev"),
  59. "Pygmalion 2.7B": ("PygmalionAI", "pygmalion-2.7b", "main"),
  60. "Pygmalion 1.3B": ("PygmalionAI", "pygmalion-1.3b", "main"),
  61. "Pygmalion 350m": ("PygmalionAI", "pygmalion-350m", "main"),
  62. "OPT 6.7b": ("facebook", "opt-6.7b", "main"),
  63. "OPT 2.7b": ("facebook", "opt-2.7b", "main"),
  64. "OPT 1.3b": ("facebook", "opt-1.3b", "main"),
  65. "OPT 350m": ("facebook", "opt-350m", "main"),
  66. }
  67. choices = {}
  68. print("Select the model that you want to download:\n")
  69. for i,name in enumerate(models):
  70. char = chr(ord('A')+i)
  71. choices[char] = name
  72. print(f"{char}) {name}")
  73. char = chr(ord('A')+len(models))
  74. print(f"{char}) None of the above")
  75. print()
  76. print("Input> ", end='')
  77. choice = input()[0].strip().upper()
  78. if choice == char:
  79. print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
  80. Examples:
  81. PygmalionAI/pygmalion-6b
  82. facebook/opt-1.3b
  83. """)
  84. print("Input> ", end='')
  85. model = input()
  86. branch = "main"
  87. else:
  88. arr = models[choices[choice]]
  89. model = f"{arr[0]}/{arr[1]}"
  90. branch = arr[2]
  91. return model, branch
  92. def get_download_links_from_huggingface(model, branch):
  93. base = "https://huggingface.co"
  94. page = f"/api/models/{model}/tree/{branch}?cursor="
  95. cursor = b""
  96. links = []
  97. sha256 = []
  98. classifications = []
  99. has_pytorch = False
  100. has_pt = False
  101. has_safetensors = False
  102. is_lora = False
  103. while True:
  104. content = requests.get(f"{base}{page}{cursor.decode()}").content
  105. dict = json.loads(content)
  106. if len(dict) == 0:
  107. break
  108. for i in range(len(dict)):
  109. fname = dict[i]['path']
  110. if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
  111. is_lora = True
  112. is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
  113. is_safetensors = re.match(".*\.safetensors", fname)
  114. is_pt = re.match(".*\.pt", fname)
  115. is_tokenizer = re.match("tokenizer.*\.model", fname)
  116. is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
  117. if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
  118. if 'lfs' in dict[i]:
  119. sha256.append([fname, dict[i]['lfs']['oid']])
  120. if is_text:
  121. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  122. classifications.append('text')
  123. continue
  124. if not args.text_only:
  125. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  126. if is_safetensors:
  127. has_safetensors = True
  128. classifications.append('safetensors')
  129. elif is_pytorch:
  130. has_pytorch = True
  131. classifications.append('pytorch')
  132. elif is_pt:
  133. has_pt = True
  134. classifications.append('pt')
  135. cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
  136. cursor = base64.b64encode(cursor)
  137. cursor = cursor.replace(b'=', b'%3D')
  138. # If both pytorch and safetensors are available, download safetensors only
  139. if (has_pytorch or has_pt) and has_safetensors:
  140. for i in range(len(classifications)-1, -1, -1):
  141. if classifications[i] in ['pytorch', 'pt']:
  142. links.pop(i)
  143. return links, sha256, is_lora
  144. def download_files(file_list, output_folder, num_threads=8):
  145. thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
  146. if __name__ == '__main__':
  147. model = args.MODEL
  148. branch = args.branch
  149. if model is None:
  150. model, branch = select_model_from_default_options()
  151. else:
  152. if model[-1] == '/':
  153. model = model[:-1]
  154. branch = args.branch
  155. if branch is None:
  156. branch = "main"
  157. else:
  158. try:
  159. branch = sanitize_branch_name(branch)
  160. except ValueError as err_branch:
  161. print(f"Error: {err_branch}")
  162. sys.exit()
  163. links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
  164. if args.output is not None:
  165. base_folder = args.output
  166. else:
  167. base_folder = 'models' if not is_lora else 'loras'
  168. output_folder = f"{'_'.join(model.split('/')[-2:])}"
  169. if branch != 'main':
  170. output_folder += f'_{branch}'
  171. # Creating the folder and writing the metadata
  172. output_folder = Path(base_folder) / output_folder
  173. if not output_folder.exists():
  174. output_folder.mkdir()
  175. with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
  176. f.write(f'url: https://huggingface.co/{model}\n')
  177. f.write(f'branch: {branch}\n')
  178. f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
  179. sha256_str = ''
  180. for i in range(len(sha256)):
  181. sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
  182. if sha256_str != '':
  183. f.write(f'sha256sum:\n{sha256_str}')
  184. # Downloading the files
  185. print(f"Downloading the model to {output_folder}")
  186. download_files(links, output_folder, args.threads)
  187. # Validate the checksums
  188. validated = True
  189. for i in range(len(sha256)):
  190. with open(output_folder / sha256[i][0], "rb") as f:
  191. bytes = f.read()
  192. file_hash = hashlib.sha256(bytes).hexdigest()
  193. if file_hash != sha256[i][1]:
  194. print(f'[!] Checksum for {sha256[i][0]} failed!')
  195. validated = False
  196. if validated:
  197. print('[+] Validated checksums of all model files!')
  198. else:
  199. print('[-] Rerun the download-model.py with --clean flag')