Переглянути джерело

Use download-model.py to download the model

oobabooga 2 роки тому
батько
коміт
2c14df81a8
2 змінених файлів з 53 додано та 75 видалено
  1. 11 11
      download-model.py
  2. 42 64
      server.py

+ 11 - 11
download-model.py

@@ -20,17 +20,6 @@ import tqdm
 from tqdm.contrib.concurrent import thread_map
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument('MODEL', type=str, default=None, nargs='?')
-parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
-parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
-parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
-parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
-parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
-parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
-args = parser.parse_args()
-
-
 def select_model_from_default_options():
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
 
 
 if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('MODEL', type=str, default=None, nargs='?')
+    parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+    parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
+    parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
+    parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+    parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
+    parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
+    args = parser.parse_args()
+
     branch = args.branch
     model = args.MODEL
     if model is None:

+ 42 - 64
server.py

@@ -2,17 +2,21 @@ import os
 
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
+import importlib
 import io
 import json
+import os
 import re
 import sys
 import time
+import traceback
 import zipfile
 from datetime import datetime
 from pathlib import Path
-import os
-import requests
+
 import gradio as gr
+import requests
+from huggingface_hub import HfApi
 from PIL import Image
 
 import modules.extensions as extensions_module
@@ -21,7 +25,6 @@ from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import generate_reply, stop_everything_event
-from huggingface_hub import HfApi
 
 # Loading custom settings
 settings_file = None
@@ -175,59 +178,31 @@ def create_prompt_menus():
 
 
 def download_model_wrapper(repo_id):
-    print(repo_id)
-    if repo_id == '':
-        print("Please enter a valid repo ID. This field cant be empty")
-    else:
-        try:
-            print('Downloading repo')
-            hf_api = HfApi()
-            # Get repo info
-            repo_info = hf_api.repo_info(
-                repo_id=repo_id,
-                repo_type="model",
-                revision="main"
-            )
-            # create model and repo folder and check for lora
-            is_lora = False
-            for file in repo_info.siblings:
-                if 'adapter_model.bin' in file.rfilename:
-                    is_lora = True
-            repo_dir_name = repo_id.replace("/", "--")
-            if is_lora is True:
-                models_dir = "loras"
-            else:
-                models_dir = "models"
-            if not os.path.exists(models_dir):
-                os.makedirs(models_dir)
-            repo_dir = os.path.join(models_dir, repo_dir_name)
-            if not os.path.exists(repo_dir):
-                os.makedirs(repo_dir)
-
-            for sibling in repo_info.siblings:
-                filename = sibling.rfilename
-                url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
-                download_path = os.path.join(repo_dir, filename)
-                response = requests.get(url, stream=True)
-                # Get the total file size from the content-length header
-                total_size = int(response.headers.get('content-length', 0))
-
-                # Download the file in chunks and print progress
-                with open(download_path, 'wb') as f:
-                    downloaded_size = 0
-                    for data in response.iter_content(chunk_size=10000000):
-                        downloaded_size += len(data)
-                        f.write(data)
-                        progress = downloaded_size * 100 // total_size
-                        downloaded_size_mb = downloaded_size / (1024 * 1024)
-                        total_size_mb = total_size / (1024 * 1024)
-                        print(f"\rDownloading {filename}... {progress}% complete "
-                              f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
-                    print(f"\rDownloading {filename}... Complete!")
-
-            print('Repo Downloaded')
-        except ValueError as e:
-            raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
+    try:
+        downloader = importlib.import_module("download-model")
+
+        model = repo_id
+        branch = "main"
+        check = False
+
+        yield("Cleaning up the model/branch names")
+        model, branch = downloader.sanitize_model_and_branch_names(model, branch)
+
+        yield("Getting the download links from Hugging Face")
+        links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
+
+        yield("Getting the output folder")
+        output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+        if check:
+            yield("Checking previously downloaded files")
+            downloader.check_model_files(model, branch, links, sha256, output_folder)
+        else:
+            yield("Downloading files")
+            downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
+            yield("Done!")
+    except:
+        yield traceback.format_exc()
 
 
 def create_model_menus():
@@ -241,17 +216,20 @@ def create_model_menus():
                 shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
                 ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
     with gr.Row():
-        with gr.Column(scale=0.5):
-            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
-                                                            info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
-    with gr.Row():
-        with gr.Column(scale=0.5):
-            shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
-            shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
-                                                   show_progress=True)
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
+                                                                    info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
+                with gr.Column():
+                    shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
+                    shared.gradio['download_status'] = gr.Markdown()
+        with gr.Column():
+            pass
 
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+    shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
 
 
 def create_settings_menus(default_preset):