Sfoglia il codice sorgente

Fix the download script for long lists of files on HF

oobabooga 2 anni fa
parent
commit
249c268176
1 ha cambiato i file con 11 aggiunte e 4 eliminazioni
  1. 11 4
      download-model.py

+ 11 - 4
download-model.py

@@ -5,7 +5,9 @@ Example:
 python download-model.py facebook/opt-1.3b
 
 '''
+
 import argparse
+import base64
 import json
 import multiprocessing
 import re
@@ -93,14 +95,18 @@ facebook/opt-1.3b
 def get_download_links_from_huggingface(model, branch):
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
+    cursor = b""
 
     links = []
     classifications = []
     has_pytorch = False
     has_safetensors = False
-    while page is not None:
-        content = requests.get(f"{base}{page}").content
+    while True:
+        content = requests.get(f"{base}{page}{cursor.decode()}").content
+
         dict = json.loads(content)
+        if len(dict) == 0:
+            break
 
         for i in range(len(dict)):
             fname = dict[i]['path']
@@ -123,8 +129,9 @@ def get_download_links_from_huggingface(model, branch):
                         has_pytorch = True
                         classifications.append('pytorch')
 
-        #page = dict['nextUrl']
-        page = None
+        cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
+        cursor = base64.b64encode(cursor)
+        cursor = cursor.replace(b'=', b'%3D')
 
     # If both pytorch and safetensors are available, download safetensors only
     if has_pytorch and has_safetensors: