Просмотр исходного кода

Make download-model.py importable

oobabooga 2 лет назад
Родитель
Сommit
34ec02d41d
1 измененных файлов с 116 добавлено и 102 удалено
  1. 116 102
      download-model.py

+ 116 - 102
download-model.py

@@ -2,7 +2,7 @@
 Downloads models from Hugging Face to models/model-name.
 
 Example:
-python download-model.py facebook/opt-1.3b
+python download_model.py facebook/opt-1.3b
 
 '''
 
@@ -19,6 +19,7 @@ import requests
 import tqdm
 from tqdm.contrib.concurrent import thread_map
 
+
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
 parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
@@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum
 args = parser.parse_args()
 
 
-def get_file(url, output_folder):
-    filename = Path(url.rsplit('/', 1)[1])
-    output_path = output_folder / filename
-    if output_path.exists() and not args.clean:
-        # Check if the file has already been downloaded completely
-        r = requests.get(url, stream=True)
-        total_size = int(r.headers.get('content-length', 0))
-        if output_path.stat().st_size >= total_size:
-            return
-        # Otherwise, resume the download from where it left off
-        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-        mode = 'ab'
-    else:
-        headers = {}
-        mode = 'wb'
-
-    r = requests.get(url, stream=True, headers=headers)
-    with open(output_path, mode) as f:
-        total_size = int(r.headers.get('content-length', 0))
-        block_size = 1024
-        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
-            for data in r.iter_content(block_size):
-                t.update(len(data))
-                f.write(data)
-
-
-def sanitize_branch_name(branch_name):
-    pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
-    if pattern.match(branch_name):
-        return branch_name
-    else:
-        raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
-
-
 def select_model_from_default_options():
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped
     return model, branch
 
 
-def get_download_links_from_huggingface(model, branch):
+def sanitize_model_and_branch_names(model, branch):
+    if model[-1] == '/':
+        model = model[:-1]
+    if branch is None:
+        branch = "main"
+    else:
+        pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
+        if not pattern.match(branch):
+            raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
+
+    return model, branch
+
+
+def get_download_links_from_huggingface(model, branch, text_only=False):
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
     cursor = b""
@@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch):
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     classifications.append('text')
                     continue
-                if not args.text_only:
+                if not text_only:
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     if is_safetensors:
                         has_safetensors = True
@@ -177,80 +157,114 @@ def get_download_links_from_huggingface(model, branch):
     return links, sha256, is_lora
 
 
-def download_files(file_list, output_folder, num_threads=8):
-    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
-
-
-if __name__ == '__main__':
-    model = args.MODEL
-    branch = args.branch
-    if model is None:
-        model, branch = select_model_from_default_options()
-    else:
-        if model[-1] == '/':
-            model = model[:-1]
-            branch = args.branch
-        if branch is None:
-            branch = "main"
-        else:
-            try:
-                branch = sanitize_branch_name(branch)
-            except ValueError as err_branch:
-                print(f"Error: {err_branch}")
-                sys.exit()
-
-    links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
-
-    if args.output is not None:
-        base_folder = args.output
-    else:
+def get_output_folder(model, branch, is_lora, base_folder=None):
+    if base_folder is None:
         base_folder = 'models' if not is_lora else 'loras'
 
     output_folder = f"{'_'.join(model.split('/')[-2:])}"
     if branch != 'main':
         output_folder += f'_{branch}'
     output_folder = Path(base_folder) / output_folder
+    return output_folder
+
+
+def get_single_file(url, output_folder, start_from_scratch=False):
+    filename = Path(url.rsplit('/', 1)[1])
+    output_path = output_folder / filename
+    if output_path.exists() and not start_from_scratch:
+        # Check if the file has already been downloaded completely
+        r = requests.get(url, stream=True)
+        total_size = int(r.headers.get('content-length', 0))
+        if output_path.stat().st_size >= total_size:
+            return
+        # Otherwise, resume the download from where it left off
+        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+        mode = 'ab'
+    else:
+        headers = {}
+        mode = 'wb'
+
+    r = requests.get(url, stream=True, headers=headers)
+    with open(output_path, mode) as f:
+        total_size = int(r.headers.get('content-length', 0))
+        block_size = 1024
+        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
+            for data in r.iter_content(block_size):
+                t.update(len(data))
+                f.write(data)
 
-    if args.check:
-        # Validate the checksums
-        validated = True
-        for i in range(len(sha256)):
-            fpath = (output_folder / sha256[i][0])
 
-            if not fpath.exists():
-                print(f"The following file is missing: {fpath}")
+def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
+    thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
+
+
+def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
+    # Creating the folder and writing the metadata
+    if not output_folder.exists():
+        output_folder.mkdir()
+    with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
+        f.write(f'url: https://huggingface.co/{model}\n')
+        f.write(f'branch: {branch}\n')
+        f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
+        sha256_str = ''
+        for i in range(len(sha256)):
+            sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
+        if sha256_str != '':
+            f.write(f'sha256sum:\n{sha256_str}')
+
+    # Downloading the files
+    print(f"Downloading the model to {output_folder}")
+    start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
+
+
+def check_model_files(model, branch, links, sha256, output_folder):
+    # Validate the checksums
+    validated = True
+    for i in range(len(sha256)):
+        fpath = (output_folder / sha256[i][0])
+
+        if not fpath.exists():
+            print(f"The following file is missing: {fpath}")
+            validated = False
+            continue
+
+        with open(output_folder / sha256[i][0], "rb") as f:
+            bytes = f.read()
+            file_hash = hashlib.sha256(bytes).hexdigest()
+            if file_hash != sha256[i][1]:
+                print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
                 validated = False
-                continue
-
-            with open(output_folder / sha256[i][0], "rb") as f:
-                bytes = f.read()
-                file_hash = hashlib.sha256(bytes).hexdigest()
-                if file_hash != sha256[i][1]:
-                    print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
-                    validated = False
-                else:
-                    print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
-
-        if validated:
-            print('[+] Validated checksums of all model files!')
-        else:
-            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
+            else:
+                print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
 
+    if validated:
+        print('[+] Validated checksums of all model files!')
     else:
+        print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
 
-        # Creating the folder and writing the metadata
-        if not output_folder.exists():
-            output_folder.mkdir()
-        with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
-            f.write(f'url: https://huggingface.co/{model}\n')
-            f.write(f'branch: {branch}\n')
-            f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
-            sha256_str = ''
-            for i in range(len(sha256)):
-                sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
-            if sha256_str != '':
-                f.write(f'sha256sum:\n{sha256_str}')
-
-        # Downloading the files
-        print(f"Downloading the model to {output_folder}")
-        download_files(links, output_folder, args.threads)
+
+if __name__ == '__main__':
+    branch = args.branch
+    model = args.MODEL
+    if model is None:
+        model, branch = select_model_from_default_options()
+
+    # Cleaning up the model/branch names
+    try:
+        model, branch = sanitize_model_and_branch_names(model, branch)
+    except ValueError as err_branch:
+        print(f"Error: {err_branch}")
+        sys.exit()
+
+    # Getting the download links from Hugging Face
+    links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
+
+    # Getting the output folder
+    output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
+
+    if args.check:
+        # Check previously downloaded files
+        check_model_files(model, branch, links, sha256, output_folder)
+    else:
+        # Download files
+        download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)