Ver código fonte

Add safetensors support

oobabooga 3 anos atrás
pai
commit
03f084f311
2 arquivos alterados com 4 adições e 1 exclusões
  1. 3 1
      download-model.py
  2. 1 0
      requirements.txt

+ 3 - 1
download-model.py

@@ -71,7 +71,9 @@ if __name__ == '__main__':
     for link in links:
         href = link.get('href')[1:]
         if href.startswith(f'{model}/resolve/{branch}'):
-            if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
+            is_pytorch = href.endswith('.bin') and 'pytorch_model' in href
+            is_safetensors = href.endswith('.safetensors') and 'model' in href
+            if href.endswith(('.json', '.txt')) or is_pytorch or is_safetensors:
                 downloads.append(f'https://huggingface.co/{href}')
 
     # Downloading the files

+ 1 - 0
requirements.txt

@@ -4,3 +4,4 @@ deepspeed==0.8.0
 gradio==3.15.0
 transformers==4.25.1
 beautifulsoup4
+safetensors