download-model.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import requests
  7. from bs4 import BeautifulSoup
  8. import multiprocessing
  9. import tqdm
  10. import sys
  11. from sys import argv
  12. import argparse
  13. from pathlib import Path
  14. import re
  15. parser = argparse.ArgumentParser()
  16. parser.add_argument('MODEL', type=str)
  17. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  18. args = parser.parse_args()
  19. def get_file(args):
  20. url = args[0]
  21. output_folder = args[1]
  22. r = requests.get(url, stream=True)
  23. with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
  24. total_size = int(r.headers.get('content-length', 0))
  25. block_size = 1024
  26. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  27. for data in r.iter_content(block_size):
  28. t.update(len(data))
  29. f.write(data)
  30. t.close()
  31. def sanitize_branch_name(branch_name):
  32. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  33. if pattern.match(branch_name):
  34. return branch_name
  35. else:
  36. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  37. if __name__ == '__main__':
  38. model = argv[1]
  39. if model[-1] == '/':
  40. model = model[:-1]
  41. branch = args.branch
  42. if args.branch is None:
  43. branch = 'main'
  44. else:
  45. try:
  46. branch_name = args.branch
  47. branch = sanitize_branch_name(branch_name)
  48. except ValueError as err_branch:
  49. print(f"Error: {err_branch}")
  50. sys.exit()
  51. url = f'https://huggingface.co/{model}/tree/{branch}'
  52. if branch != 'main':
  53. output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
  54. else:
  55. output_folder = Path("models") / model.split('/')[-1]
  56. if not output_folder.exists():
  57. output_folder.mkdir()
  58. # Finding the relevant files to download
  59. page = requests.get(url)
  60. soup = BeautifulSoup(page.content, 'html.parser')
  61. links = soup.find_all('a')
  62. downloads = []
  63. for link in links:
  64. href = link.get('href')[1:]
  65. if href.startswith(f'{model}/resolve/{branch}'):
  66. if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
  67. downloads.append(f'https://huggingface.co/{href}')
  68. # Downloading the files
  69. print(f"Downloading the model to {output_folder}...")
  70. pool = multiprocessing.Pool(processes=4)
  71. results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
  72. pool.close()
  73. pool.join()