Jelajahi Sumber

Add support for resuming downloads (#654 from nikita-skakun/support-partial-downloads)

oobabooga 2 tahun lalu
induk
melakukan
23116b88ef
1 mengubah file dengan 63 tambahan dan 19 penghapusan
  1. 63 19
      download-model.py

+ 63 - 19
download-model.py

@@ -9,6 +9,7 @@ python download-model.py facebook/opt-1.3b
 import argparse
 import base64
 import datetime
+import hashlib
 import json
 import re
 import sys
@@ -24,11 +25,28 @@ parser.add_argument('--branch', type=str, default='main', help='Name of the Git
 parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
 parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
 parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
+parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
 args = parser.parse_args()
 
 def get_file(url, output_folder):
-    r = requests.get(url, stream=True)
-    with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
+    filename = Path(url.rsplit('/', 1)[1])
+    output_path = output_folder / filename
+    if output_path.exists() and not args.clean:
+        # Check if the file has already been downloaded completely
+        r = requests.get(url, stream=True)
+        total_size = int(r.headers.get('content-length', 0))
+        if output_path.stat().st_size >= total_size:
+            return
+        # Otherwise, resume the download from where it left off
+        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+        mode = 'ab'
+    else:
+        headers = {}
+        mode = 'wb'
+
+    r = requests.get(url, stream=True, headers=headers)
+    with open(output_path, mode) as f:
         total_size = int(r.headers.get('content-length', 0))
         block_size = 1024
         with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
@@ -154,7 +172,7 @@ def get_download_links_from_huggingface(model, branch):
     return links, sha256, is_lora
 
 def download_files(file_list, output_folder, num_threads=8):
-    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
+    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
 
 if __name__ == '__main__':
     model = args.MODEL
@@ -184,22 +202,48 @@ if __name__ == '__main__':
     output_folder = f"{'_'.join(model.split('/')[-2:])}"
     if branch != 'main':
         output_folder += f'_{branch}'
-
-    # Creating the folder and writing the metadata
     output_folder = Path(base_folder) / output_folder
-    if not output_folder.exists():
-        output_folder.mkdir()
-    with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
-        f.write(f'url: https://huggingface.co/{model}\n')
-        f.write(f'branch: {branch}\n')
-        f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
-        sha256_str = ''
+
+    if args.check:
+        # Validate the checksums
+        validated = True
         for i in range(len(sha256)):
-            sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
-        if sha256_str != '':
-            f.write(f'sha256sum:\n{sha256_str}')
+            fpath = (output_folder / sha256[i][0])
+
+            if not fpath.exists():
+                print(f"The following file is missing: {fpath}")
+                validated = False
+                continue
+
+            with open(output_folder / sha256[i][0], "rb") as f:
+                bytes = f.read()
+                file_hash = hashlib.sha256(bytes).hexdigest()
+                if file_hash != sha256[i][1]:
+                    print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
+                    validated = False
+                else:
+                    print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
+        
+        if validated:
+            print('[+] Validated checksums of all model files!')
+        else:
+            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
 
-    # Downloading the files
-    print(f"Downloading the model to {output_folder}")
-    download_files(links, output_folder, args.threads)
-    print()
+    else:
+
+        # Creating the folder and writing the metadata
+        if not output_folder.exists():
+            output_folder.mkdir()
+        with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
+            f.write(f'url: https://huggingface.co/{model}\n')
+            f.write(f'branch: {branch}\n')
+            f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
+            sha256_str = ''
+            for i in range(len(sha256)):
+                sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
+            if sha256_str != '':
+                f.write(f'sha256sum:\n{sha256_str}')
+
+        # Downloading the files
+        print(f"Downloading the model to {output_folder}")
+        download_files(links, output_folder, args.threads)