download-model.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. '''
  2. Downloads models from Hugging Face to models/model-name.
  3. Example:
  4. python download-model.py facebook/opt-1.3b
  5. '''
  6. import requests
  7. from bs4 import BeautifulSoup
  8. import multiprocessing
  9. import os
  10. import tqdm
  11. from sys import argv
  12. def get_file(args):
  13. url = args[0]
  14. output_folder = args[1]
  15. r = requests.get(url, stream=True)
  16. with open(f"{output_folder}/{url.split('/')[-1]}", 'wb') as f:
  17. total_size = int(r.headers.get('content-length', 0))
  18. block_size = 1024
  19. t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
  20. for data in r.iter_content(block_size):
  21. t.update(len(data))
  22. f.write(data)
  23. t.close()
  24. model = argv[1]
  25. if model.endswith('/'):
  26. model = model[:-1]
  27. url = f'https://huggingface.co/{model}/tree/main'
  28. output_folder = f"models/{model.split('/')[-1]}"
  29. if not os.path.exists(output_folder):
  30. os.mkdir(output_folder)
  31. # Finding the relevant files to download
  32. page = requests.get(url)
  33. soup = BeautifulSoup(page.content, 'html.parser')
  34. links = soup.find_all('a')
  35. downloads = []
  36. for link in links:
  37. href = link.get('href')[1:]
  38. if href.startswith(f'{model}/resolve/main'):
  39. if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
  40. downloads.append(f'https://huggingface.co/{href}')
  41. # Downloading the files
  42. print(f"Downloading the model to {output_folder}...")
  43. pool = multiprocessing.Pool(processes=4)
  44. results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
  45. pool.close()
  46. pool.join()