Browse Source

Properly scrape huggingface for download links (for #122)

oobabooga 2 years ago
parent
commit
fe1771157f
2 changed files with 46 additions and 39 deletions
  1. 46 38
      download-model.py
  2. 0 1
      requirements.txt

+ 46 - 38
download-model.py

@@ -6,6 +6,7 @@ python download-model.py facebook/opt-1.3b
 
 '''
 import argparse
+import json
 import multiprocessing
 import re
 import sys
@@ -13,7 +14,6 @@ from pathlib import Path
 
 import requests
 import tqdm
-from bs4 import BeautifulSoup
 
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
@@ -90,54 +90,32 @@ facebook/opt-1.3b
 
     return model, branch
 
-if __name__ == '__main__':
-    model = args.MODEL
-    branch = args.branch
-    if model is None:
-        model, branch = select_model_from_default_options()
-    else:
-        if model[-1] == '/':
-            model = model[:-1]
-            branch = args.branch
-        if branch is None:
-            branch = "main"
-        else:
-            try:
-                branch = sanitize_branch_name(branch)
-            except ValueError as err_branch:
-                print(f"Error: {err_branch}")
-                sys.exit()
-    url = f'https://huggingface.co/{model}/tree/{branch}'
-    if branch != 'main':
-        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
-    else:
-        output_folder = Path("models") / model.split('/')[-1]
-    if not output_folder.exists():
-        output_folder.mkdir()
+def get_download_links_from_huggingface(model, branch):
+    base = "https://huggingface.co"
+    page = f"/api/models/{model}/tree/{branch}?cursor="
 
-    # Finding the relevant files to download
-    page = requests.get(url) 
-    soup = BeautifulSoup(page.content, 'html.parser') 
-    links = soup.find_all('a')
-    downloads = []
+    links = []
     classifications = []
     has_pytorch = False
     has_safetensors = False
-    for link in links:
-        href = link.get('href')[1:]
-        if href.startswith(f'{model}/resolve/{branch}'):
-            fname = Path(href).name
+    while page is not None:
+        content = requests.get(f"{base}{page}").content
+        dict = json.loads(content)
+
+        for i in range(len(dict['items'])):
+            fname = dict['items'][i]['path']
+
             is_pytorch = re.match("pytorch_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
             is_text = re.match(".*\.(txt|json)", fname)
 
             if is_text or is_safetensors or is_pytorch:
                 if is_text:
-                    downloads.append(f'https://huggingface.co/{href}')
+                    links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     classifications.append('text')
                     continue
                 if not args.text_only:
-                    downloads.append(f'https://huggingface.co/{href}')
+                    links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     if is_safetensors:
                         has_safetensors = True
                         classifications.append('safetensors')
@@ -145,15 +123,45 @@ if __name__ == '__main__':
                         has_pytorch = True
                         classifications.append('pytorch')
 
+        page = dict['nextUrl']
+
     # If both pytorch and safetensors are available, download safetensors only
     if has_pytorch and has_safetensors:
         for i in range(len(classifications)-1, -1, -1):
             if classifications[i] == 'pytorch':
-                downloads.pop(i)
+                links.pop(i)
+
+    return links
+
+if __name__ == '__main__':
+    model = args.MODEL
+    branch = args.branch
+    if model is None:
+        model, branch = select_model_from_default_options()
+    else:
+        if model[-1] == '/':
+            model = model[:-1]
+            branch = args.branch
+        if branch is None:
+            branch = "main"
+        else:
+            try:
+                branch = sanitize_branch_name(branch)
+            except ValueError as err_branch:
+                print(f"Error: {err_branch}")
+                sys.exit()
+    if branch != 'main':
+        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+    else:
+        output_folder = Path("models") / model.split('/')[-1]
+    if not output_folder.exists():
+        output_folder.mkdir()
+
+    links = get_download_links_from_huggingface(model, branch)
 
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     pool = multiprocessing.Pool(processes=args.threads)
-    results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
+    results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
     pool.close()
     pool.join()

+ 0 - 1
requirements.txt

@@ -1,5 +1,4 @@
 accelerate==0.16.0
-beautifulsoup4
 bitsandbytes==0.37.0
 gradio==3.18.0
 numpy