download-model.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import argparse
  7. import multiprocessing
  8. import re
  9. import sys
  10. from pathlib import Path
  11. import requests
  12. import tqdm
  13. from bs4 import BeautifulSoup
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument('MODEL', type=str)
  16. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  17. parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
  18. args = parser.parse_args()
  19. def get_file(args):
  20. url = args[0]
  21. output_folder = args[1]
  22. idx = args[2]
  23. tot = args[3]
  24. print(f"Downloading file {idx} of {tot}...")
  25. r = requests.get(url, stream=True)
  26. with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
  27. total_size = int(r.headers.get('content-length', 0))
  28. block_size = 1024
  29. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  30. for data in r.iter_content(block_size):
  31. t.update(len(data))
  32. f.write(data)
  33. t.close()
  34. def sanitize_branch_name(branch_name):
  35. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  36. if pattern.match(branch_name):
  37. return branch_name
  38. else:
  39. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  40. if __name__ == '__main__':
  41. model = args.MODEL
  42. if model[-1] == '/':
  43. model = model[:-1]
  44. branch = args.branch
  45. if args.branch is None:
  46. branch = 'main'
  47. else:
  48. try:
  49. branch_name = args.branch
  50. branch = sanitize_branch_name(branch_name)
  51. except ValueError as err_branch:
  52. print(f"Error: {err_branch}")
  53. sys.exit()
  54. url = f'https://huggingface.co/{model}/tree/{branch}'
  55. if branch != 'main':
  56. output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
  57. else:
  58. output_folder = Path("models") / model.split('/')[-1]
  59. if not output_folder.exists():
  60. output_folder.mkdir()
  61. # Finding the relevant files to download
  62. page = requests.get(url)
  63. soup = BeautifulSoup(page.content, 'html.parser')
  64. links = soup.find_all('a')
  65. downloads = []
  66. classifications = []
  67. has_pytorch = False
  68. has_safetensors = False
  69. for link in links:
  70. href = link.get('href')[1:]
  71. if href.startswith(f'{model}/resolve/{branch}'):
  72. fname = Path(href).name
  73. is_pytorch = re.match("pytorch_model.*\.bin", fname)
  74. is_safetensors = re.match("model.*\.safetensors", fname)
  75. is_text = re.match(".*\.(txt|json)", fname)
  76. if is_text or is_safetensors or is_pytorch:
  77. downloads.append(f'https://huggingface.co/{href}')
  78. if is_text:
  79. classifications.append('text')
  80. elif is_safetensors:
  81. has_safetensors = True
  82. classifications.append('safetensors')
  83. elif is_pytorch:
  84. has_pytorch = True
  85. classifications.append('pytorch')
  86. # If both pytorch and safetensors are available, download safetensors only
  87. if has_pytorch and has_safetensors:
  88. for i in range(len(classifications)-1, -1, -1):
  89. if classifications[i] == 'pytorch':
  90. downloads.pop(i)
  91. # Downloading the files
  92. print(f"Downloading the model to {output_folder}")
  93. pool = multiprocessing.Pool(processes=args.threads)
  94. results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
  95. pool.close()
  96. pool.join()