Просмотр исходного кода

Only download safetensors if both pytorch and safetensors are present

oobabooga 3 года назад
Родитель
Коммит
66862203fc
1 изменённый файл: 23 добавлено и 3 удалено
  1. 23 3
      download-model.py

+ 23 - 3
download-model.py

@@ -72,13 +72,33 @@ if __name__ == '__main__':
     soup = BeautifulSoup(page.content, 'html.parser') 
     links = soup.find_all('a')
     downloads = []
+    classifications = []
+    has_pytorch = False
+    has_safetensors = False
     for link in links:
         href = link.get('href')[1:]
         if href.startswith(f'{model}/resolve/{branch}'):
-            is_pytorch = href.endswith('.bin') and 'pytorch_model' in href
-            is_safetensors = href.endswith('.safetensors') and 'model' in href
-            if href.endswith(('.json', '.txt')) or is_pytorch or is_safetensors:
+            fname = Path(href).name
+            is_pytorch = re.match("pytorch_model.*\.bin", fname)
+            is_safetensors = re.match("model.*\.safetensors", fname)
+            is_text = re.match(".*\.(txt|json)", fname)
+
+            if is_text or is_safetensors or is_pytorch:
                 downloads.append(f'https://huggingface.co/{href}')
+                if is_text:
+                    classifications.append('text')
+                elif is_safetensors:
+                    has_safetensors = True
+                    classifications.append('safetensors')
+                elif is_pytorch:
+                    has_pytorch = True
+                    classifications.append('pytorch')
+
+    # If both pytorch and safetensors are available, download safetensors only
+    if has_pytorch and has_safetensors:
+        for i in range(len(classifications)-1, -1, -1):
+            if classifications[i] == 'pytorch':
+                downloads.pop(i)
 
     # Downloading the files
     print(f"Downloading the model to {output_folder}")