download-model.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
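'''
Downloads models from Hugging Face to models/model-name.

Pass --branch to download from a Git branch other than "main"; the files
are then saved to models/model-name_branch instead.

Example:
python download-model.py facebook/opt-1.3b
'''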
  6. import requests
  7. from bs4 import BeautifulSoup
  8. import multiprocessing
  9. import tqdm
  10. import sys
  11. import argparse
  12. from pathlib import Path
  13. import re
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument('MODEL', type=str)
  16. parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
  17. args = parser.parse_args()
  18. def get_file(args):
  19. url = args[0]
  20. output_folder = args[1]
  21. r = requests.get(url, stream=True)
  22. with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
  23. total_size = int(r.headers.get('content-length', 0))
  24. block_size = 1024
  25. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  26. for data in r.iter_content(block_size):
  27. t.update(len(data))
  28. f.write(data)
  29. t.close()
  30. def sanitize_branch_name(branch_name):
  31. pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
  32. if pattern.match(branch_name):
  33. return branch_name
  34. else:
  35. raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
  36. if __name__ == '__main__':
  37. model = args.model
  38. if model[-1] == '/':
  39. model = model[:-1]
  40. branch = args.branch
  41. if args.branch is None:
  42. branch = 'main'
  43. else:
  44. try:
  45. branch_name = args.branch
  46. branch = sanitize_branch_name(branch_name)
  47. except ValueError as err_branch:
  48. print(f"Error: {err_branch}")
  49. sys.exit()
  50. url = f'https://huggingface.co/{model}/tree/{branch}'
  51. if branch != 'main':
  52. output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
  53. else:
  54. output_folder = Path("models") / model.split('/')[-1]
  55. if not output_folder.exists():
  56. output_folder.mkdir()
  57. # Finding the relevant files to download
  58. page = requests.get(url)
  59. soup = BeautifulSoup(page.content, 'html.parser')
  60. links = soup.find_all('a')
  61. downloads = []
  62. for link in links:
  63. href = link.get('href')[1:]
  64. if href.startswith(f'{model}/resolve/{branch}'):
  65. if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
  66. downloads.append(f'https://huggingface.co/{href}')
  67. # Downloading the files
  68. print(f"Downloading the model to {output_folder}...")
  69. pool = multiprocessing.Pool(processes=4)
  70. results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
  71. pool.close()
  72. pool.join()