download-model.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import requests
  7. from bs4 import BeautifulSoup
  8. import multiprocessing
  9. import tqdm
  10. import sys
  11. import argparse
  12. from pathlib import Path
  13. import re
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument('MODEL', type=str)
  16. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  17. parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
  18. args = parser.parse_args()
  19. def get_file(args):
  20. url = args[0]
  21. output_folder = args[1]
  22. idx = args[2]
  23. tot = args[3]
  24. print(f"Downloading file {idx} of {tot}...")
  25. r = requests.get(url, stream=True)
  26. with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
  27. total_size = int(r.headers.get('content-length', 0))
  28. block_size = 1024
  29. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  30. for data in r.iter_content(block_size):
  31. t.update(len(data))
  32. f.write(data)
  33. t.close()
  34. def sanitize_branch_name(branch_name):
  35. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  36. if pattern.match(branch_name):
  37. return branch_name
  38. else:
  39. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  40. if __name__ == '__main__':
  41. model = args.MODEL
  42. if model[-1] == '/':
  43. model = model[:-1]
  44. branch = args.branch
  45. if args.branch is None:
  46. branch = 'main'
  47. else:
  48. try:
  49. branch_name = args.branch
  50. branch = sanitize_branch_name(branch_name)
  51. except ValueError as err_branch:
  52. print(f"Error: {err_branch}")
  53. sys.exit()
  54. url = f'https://huggingface.co/{model}/tree/{branch}'
  55. if branch != 'main':
  56. output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
  57. else:
  58. output_folder = Path("models") / model.split('/')[-1]
  59. if not output_folder.exists():
  60. output_folder.mkdir()
  61. # Finding the relevant files to download
  62. page = requests.get(url)
  63. soup = BeautifulSoup(page.content, 'html.parser')
  64. links = soup.find_all('a')
  65. downloads = []
  66. for link in links:
  67. href = link.get('href')[1:]
  68. if href.startswith(f'{model}/resolve/{branch}'):
  69. is_pytorch = href.endswith('.bin') and 'pytorch_model' in href
  70. is_safetensors = href.endswith('.safetensors') and 'model' in href
  71. if href.endswith(('.json', '.txt')) or is_pytorch or is_safetensors:
  72. downloads.append(f'https://huggingface.co/{href}')
  73. # Downloading the files
  74. print(f"Downloading the model to {output_folder}")
  75. pool = multiprocessing.Pool(processes=args.threads)
  76. results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
  77. pool.close()
  78. pool.join()