Forráskód Böngészése

Fix the download script on windows (#6)

oobabooga 3 éve
szülő
commit
fcda5d7107
1 módosított fájl, 27 hozzáadás és 25 törlés
  1. 27 25
      download-model.py

+ 27 - 25
download-model.py

@@ -27,28 +27,30 @@ def get_file(args):
             f.write(data)
         t.close()
 
-model = argv[1]
-if model[-1] == '/':
-    model = model[:-1]
-url = f'https://huggingface.co/{model}/tree/main'
-output_folder = Path("models") / model.split('/')[-1]
-if not output_folder.exists():
-    output_folder.mkdir()
-
-# Finding the relevant files to download
-page = requests.get(url) 
-soup = BeautifulSoup(page.content, 'html.parser') 
-links = soup.find_all('a')
-downloads = []
-for link in links:
-    href = link.get('href')[1:]
-    if href.startswith(f'{model}/resolve/main'):
-        if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
-            downloads.append(f'https://huggingface.co/{href}')
-
-# Downloading the files
-print(f"Downloading the model to {output_folder}...")
-pool = multiprocessing.Pool(processes=4)
-results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
-pool.close()
-pool.join()
+if __name__ == '__main__':
+
+    model = argv[1]
+    if model[-1] == '/':
+        model = model[:-1]
+    url = f'https://huggingface.co/{model}/tree/main'
+    output_folder = Path("models") / model.split('/')[-1]
+    if not output_folder.exists():
+        output_folder.mkdir()
+
+    # Finding the relevant files to download
+    page = requests.get(url) 
+    soup = BeautifulSoup(page.content, 'html.parser') 
+    links = soup.find_all('a')
+    downloads = []
+    for link in links:
+        href = link.get('href')[1:]
+        if href.startswith(f'{model}/resolve/main'):
+            if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
+                downloads.append(f'https://huggingface.co/{href}')
+
+    # Downloading the files
+    print(f"Downloading the model to {output_folder}...")
+    pool = multiprocessing.Pool(processes=4)
+    results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
+    pool.close()
+    pool.join()