# download-model.py
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import argparse
  7. import base64
  8. import json
  9. import multiprocessing
  10. import re
  11. import sys
  12. from pathlib import Path
  13. import requests
  14. import tqdm
  15. def get_file(args):
  16. url = args[0]
  17. output_folder = args[1]
  18. idx = args[2]
  19. tot = args[3]
  20. print(f"Downloading file {idx} of {tot}...")
  21. r = requests.get(url, stream=True)
  22. with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
  23. total_size = int(r.headers.get('content-length', 0))
  24. block_size = 1024
  25. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  26. for data in r.iter_content(block_size):
  27. t.update(len(data))
  28. f.write(data)
  29. t.close()
  30. def sanitize_branch_name(branch_name):
  31. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  32. if pattern.match(branch_name):
  33. return branch_name
  34. else:
  35. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  36. def select_model_from_default_options():
  37. models = {
  38. "Pygmalion 6B original": ("PygmalionAI", "pygmalion-6b", "b8344bb4eb76a437797ad3b19420a13922aaabe1"),
  39. "Pygmalion 6B main": ("PygmalionAI", "pygmalion-6b", "main"),
  40. "Pygmalion 6B dev": ("PygmalionAI", "pygmalion-6b", "dev"),
  41. "Pygmalion 2.7B": ("PygmalionAI", "pygmalion-2.7b", "main"),
  42. "Pygmalion 1.3B": ("PygmalionAI", "pygmalion-1.3b", "main"),
  43. "Pygmalion 350m": ("PygmalionAI", "pygmalion-350m", "main"),
  44. "OPT 6.7b": ("facebook", "opt-6.7b", "main"),
  45. "OPT 2.7b": ("facebook", "opt-2.7b", "main"),
  46. "OPT 1.3b": ("facebook", "opt-1.3b", "main"),
  47. "OPT 350m": ("facebook", "opt-350m", "main"),
  48. }
  49. choices = {}
  50. print("Select the model that you want to download:\n")
  51. for i,name in enumerate(models):
  52. char = chr(ord('A')+i)
  53. choices[char] = name
  54. print(f"{char}) {name}")
  55. char = chr(ord('A')+len(models))
  56. print(f"{char}) None of the above")
  57. print()
  58. print("Input> ", end='')
  59. choice = input()[0].strip().upper()
  60. if choice == char:
  61. print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
  62. Examples:
  63. PygmalionAI/pygmalion-6b
  64. facebook/opt-1.3b
  65. """)
  66. print("Input> ", end='')
  67. model = input()
  68. branch = "main"
  69. else:
  70. arr = models[choices[choice]]
  71. model = f"{arr[0]}/{arr[1]}"
  72. branch = arr[2]
  73. return model, branch
  74. def get_download_links_from_huggingface(model, branch):
  75. base = "https://huggingface.co"
  76. page = f"/api/models/{model}/tree/{branch}?cursor="
  77. cursor = b""
  78. links = []
  79. classifications = []
  80. has_pytorch = False
  81. has_pt = False
  82. has_safetensors = False
  83. is_lora = False
  84. while True:
  85. content = requests.get(f"{base}{page}{cursor.decode()}").content
  86. dict = json.loads(content)
  87. if len(dict) == 0:
  88. break
  89. for i in range(len(dict)):
  90. fname = dict[i]['path']
  91. if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
  92. is_lora = True
  93. is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
  94. is_safetensors = re.match(".*\.safetensors", fname)
  95. is_pt = re.match(".*\.pt", fname)
  96. is_tokenizer = re.match("tokenizer.*\.model", fname)
  97. is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
  98. if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
  99. if is_text:
  100. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  101. classifications.append('text')
  102. continue
  103. if not args.text_only:
  104. links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
  105. if is_safetensors:
  106. has_safetensors = True
  107. classifications.append('safetensors')
  108. elif is_pytorch:
  109. has_pytorch = True
  110. classifications.append('pytorch')
  111. elif is_pt:
  112. has_pt = True
  113. classifications.append('pt')
  114. cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
  115. cursor = base64.b64encode(cursor)
  116. cursor = cursor.replace(b'=', b'%3D')
  117. # If both pytorch and safetensors are available, download safetensors only
  118. if (has_pytorch or has_pt) and has_safetensors:
  119. for i in range(len(classifications)-1, -1, -1):
  120. if classifications[i] in ['pytorch', 'pt']:
  121. links.pop(i)
  122. return links, is_lora
  123. def download_files(file_list, output_folder, num_processes=8):
  124. with multiprocessing.Pool(processes=num_processes) as pool:
  125. args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)]
  126. for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)):
  127. pass
  128. pool.close()
  129. pool.join()
  130. if __name__ == '__main__':
  131. parser = argparse.ArgumentParser()
  132. parser.add_argument('MODEL', type=str, default=None, nargs='?')
  133. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  134. parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
  135. parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
  136. args = parser.parse_args()
  137. model = args.MODEL
  138. branch = args.branch
  139. if model is None:
  140. model, branch = select_model_from_default_options()
  141. else:
  142. if model[-1] == '/':
  143. model = model[:-1]
  144. branch = args.branch
  145. if branch is None:
  146. branch = "main"
  147. else:
  148. try:
  149. branch = sanitize_branch_name(branch)
  150. except ValueError as err_branch:
  151. print(f"Error: {err_branch}")
  152. sys.exit()
  153. links, is_lora = get_download_links_from_huggingface(model, branch)
  154. base_folder = 'models' if not is_lora else 'loras'
  155. if branch != 'main':
  156. output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
  157. else:
  158. output_folder = Path(base_folder) / model.split('/')[-1]
  159. if not output_folder.exists():
  160. output_folder.mkdir()
  161. # Downloading the files
  162. print(f"Downloading the model to {output_folder}")
  163. download_files(links, output_folder, num_processes=args.threads)