# download-model.py
'''
Downloads models from Hugging Face to models/model-name.

Example:
python download-model.py facebook/opt-1.3b
'''

import argparse
import base64
import datetime
import json
import multiprocessing
import re
import sys
from pathlib import Path

import requests
import tqdm

# Command-line interface. MODEL is optional: when omitted, an interactive
# menu of default models is shown instead.
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
args = parser.parse_args()
  23. def get_file(args):
  24. url = args[0]
  25. output_folder = args[1]
  26. idx = args[2]
  27. tot = args[3]
  28. print(f"Downloading file {idx} of {tot}...")
  29. r = requests.get(url, stream=True)
  30. with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
  31. total_size = int(r.headers.get('content-length', 0))
  32. block_size = 1024
  33. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  34. for data in r.iter_content(block_size):
  35. t.update(len(data))
  36. f.write(data)
  37. t.close()
  38. def sanitize_branch_name(branch_name):
  39. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  40. if pattern.match(branch_name):
  41. return branch_name
  42. else:
  43. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  44. def select_model_from_default_options():
  45. models = {
  46. "Pygmalion 6B original": ("PygmalionAI", "pygmalion-6b", "b8344bb4eb76a437797ad3b19420a13922aaabe1"),
  47. "Pygmalion 6B main": ("PygmalionAI", "pygmalion-6b", "main"),
  48. "Pygmalion 6B dev": ("PygmalionAI", "pygmalion-6b", "dev"),
  49. "Pygmalion 2.7B": ("PygmalionAI", "pygmalion-2.7b", "main"),
  50. "Pygmalion 1.3B": ("PygmalionAI", "pygmalion-1.3b", "main"),
  51. "Pygmalion 350m": ("PygmalionAI", "pygmalion-350m", "main"),
  52. "OPT 6.7b": ("facebook", "opt-6.7b", "main"),
  53. "OPT 2.7b": ("facebook", "opt-2.7b", "main"),
  54. "OPT 1.3b": ("facebook", "opt-1.3b", "main"),
  55. "OPT 350m": ("facebook", "opt-350m", "main"),
  56. }
  57. choices = {}
  58. print("Select the model that you want to download:\n")
  59. for i,name in enumerate(models):
  60. char = chr(ord('A')+i)
  61. choices[char] = name
  62. print(f"{char}) {name}")
  63. char = chr(ord('A')+len(models))
  64. print(f"{char}) None of the above")
  65. print()
  66. print("Input> ", end='')
  67. choice = input()[0].strip().upper()
  68. if choice == char:
  69. print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
  70. Examples:
  71. PygmalionAI/pygmalion-6b
  72. facebook/opt-1.3b
  73. """)
  74. print("Input> ", end='')
  75. model = input()
  76. branch = "main"
  77. else:
  78. arr = models[choices[choice]]
  79. model = f"{arr[0]}/{arr[1]}"
  80. branch = arr[2]
  81. return model, branch
  82. def get_download_links_from_huggingface(model, branch):
  83. base = "https://huggingface.co"
  84. page = f"/api/models/{model}/tree/{branch}?cursor="
  85. cursor = b""
  86. links = []
  87. classifications = []
  88. has_pytorch = False
  89. has_pt = False
  90. has_safetensors = False
  91. is_lora = False
  92. while True:
  93. content = requests.get(f"{base}{page}{cursor.decode()}").content
  94. dict = json.loads(content)
  95. if len(dict) == 0:
  96. break
  97. for i in range(len(dict)):
  98. fname = dict[i]['path']
  99. if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
  100. is_lora = True
  101. is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
  102. is_safetensors = re.match(".*\.safetensors", fname)
  103. is_pt = re.match(".*\.pt", fname)
  104. is_tokenizer = re.match("tokenizer.*\.model", fname)
  105. is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
  106. if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
  107. if is_text:
  108. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  109. classifications.append('text')
  110. continue
  111. if not args.text_only:
  112. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  113. if is_safetensors:
  114. has_safetensors = True
  115. classifications.append('safetensors')
  116. elif is_pytorch:
  117. has_pytorch = True
  118. classifications.append('pytorch')
  119. elif is_pt:
  120. has_pt = True
  121. classifications.append('pt')
  122. cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
  123. cursor = base64.b64encode(cursor)
  124. cursor = cursor.replace(b'=', b'%3D')
  125. # If both pytorch and safetensors are available, download safetensors only
  126. if (has_pytorch or has_pt) and has_safetensors:
  127. for i in range(len(classifications)-1, -1, -1):
  128. if classifications[i] in ['pytorch', 'pt']:
  129. links.pop(i)
  130. return links, is_lora
  131. if __name__ == '__main__':
  132. model = args.MODEL
  133. branch = args.branch
  134. if model is None:
  135. model, branch = select_model_from_default_options()
  136. else:
  137. if model[-1] == '/':
  138. model = model[:-1]
  139. branch = args.branch
  140. if branch is None:
  141. branch = "main"
  142. else:
  143. try:
  144. branch = sanitize_branch_name(branch)
  145. except ValueError as err_branch:
  146. print(f"Error: {err_branch}")
  147. sys.exit()
  148. links, is_lora = get_download_links_from_huggingface(model, branch)
  149. if args.output is not None:
  150. base_folder = args.output
  151. else:
  152. base_folder = 'models' if not is_lora else 'loras'
  153. output_folder = f"{'_'.join(model.split('/')[-2:])}"
  154. if branch != 'main':
  155. output_folder += f'_{branch}'
  156. # Creating the folder and writing the metadata
  157. output_folder = Path(base_folder) / output_folder
  158. if not output_folder.exists():
  159. output_folder.mkdir()
  160. with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
  161. f.write(f'url: https://huggingface.co/{model}\n')
  162. f.write(f'branch: {branch}\n')
  163. f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
  164. # Downloading the files
  165. print(f"Downloading the model to {output_folder}")
  166. pool = multiprocessing.Pool(processes=args.threads)
  167. results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
  168. pool.close()
  169. pool.join()