Ver código fonte

Don't download if --check is specified

oobabooga 2 anos atrás
pai
commit
92c7068daf
1 arquivos alterados com 17 adições e 7 exclusões
  1. 17 7
      download-model.py

+ 17 - 7
download-model.py

@@ -9,6 +9,7 @@ python download-model.py facebook/opt-1.3b
 import argparse
 import base64
 import datetime
+import hashlib
 import json
 import re
 import sys
@@ -17,7 +18,6 @@ from pathlib import Path
 import requests
 import tqdm
 from tqdm.contrib.concurrent import thread_map
-import hashlib
 
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
@@ -212,22 +212,32 @@ if __name__ == '__main__':
         if sha256_str != '':
             f.write(f'sha256sum:\n{sha256_str}')
 
-    # Downloading the files
-    print(f"Downloading the model to {output_folder}")
-    download_files(links, output_folder, args.threads)
-    
     if args.check:
         # Validate the checksums
         validated = True
         for i in range(len(sha256)):
+            fpath = (output_folder / sha256[i][0])
+
+            if not fpath.exists():
+                print(f"The following file is missing: {fpath}")
+                validated = False
+                continue
+
             with open(output_folder / sha256[i][0], "rb") as f:
                 bytes = f.read()
                 file_hash = hashlib.sha256(bytes).hexdigest()
                 if file_hash != sha256[i][1]:
-                    print(f'[!] Checksum for {sha256[i][0]} failed!')
+                    print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
                     validated = False
+                else:
+                    print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
         
         if validated:
             print('[+] Validated checksums of all model files!')
         else:
-            print('[-] Rerun the download-model.py with --clean flag')
+            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
+
+    else:
+        # Downloading the files
+        print(f"Downloading the model to {output_folder}")
+        download_files(links, output_folder, args.threads)