Parcourir la source

download custom model menu (from hugging face) added in model tab

Usama Kenway il y a 2 ans
Parent
commit
7436dd5b4a
1 fichiers modifiés avec 68 ajouts et 1 suppressions
  1. 68 1
      server.py

+ 68 - 1
server.py

@@ -10,7 +10,8 @@ import time
 import zipfile
 from datetime import datetime
 from pathlib import Path
-
+import os
+import requests
 import gradio as gr
 from PIL import Image
 
@@ -20,6 +21,7 @@ from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import generate_reply, stop_everything_event
+from huggingface_hub import HfApi
 
 # Loading custom settings
 settings_file = None
@@ -172,6 +174,62 @@ def create_prompt_menus():
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
 
 
+def download_model_wrapper(repo_id):
+    print(repo_id)
+    if repo_id == '':
+        print("Please enter a valid repo ID. This field cant be empty")
+    else:
+        try:
+            print('Downloading repo')
+            hf_api = HfApi()
+            # Get repo info
+            repo_info = hf_api.repo_info(
+                repo_id=repo_id,
+                repo_type="model",
+                revision="main"
+            )
+            # create model and repo folder and check for lora
+            is_lora = False
+            for file in repo_info.siblings:
+                if 'adapter_model.bin' in file.rfilename:
+                    is_lora = True
+            repo_dir_name = repo_id.replace("/", "--")
+            if is_lora is True:
+                models_dir = ".loras"
+            else:
+                models_dir = ".models"
+            if not os.path.exists(models_dir):
+                os.makedirs(models_dir)
+            repo_dir = os.path.join(models_dir, repo_dir_name)
+            if not os.path.exists(repo_dir):
+                os.makedirs(repo_dir)
+
+            for sibling in repo_info.siblings:
+                filename = sibling.rfilename
+                url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
+                download_path = os.path.join(repo_dir, filename)
+                response = requests.get(url, stream=True)
+                # Get the total file size from the content-length header
+                total_size = int(response.headers.get('content-length', 0))
+
+                # Download the file in chunks and print progress
+                with open(download_path, 'wb') as f:
+                    downloaded_size = 0
+                    for data in response.iter_content(chunk_size=10000000):
+                        downloaded_size += len(data)
+                        f.write(data)
+                        progress = downloaded_size * 100 // total_size
+                        downloaded_size_mb = downloaded_size / (1024 * 1024)
+                        total_size_mb = total_size / (1024 * 1024)
+                        print(f"\rDownloading {filename}... {progress}% complete "
+                              f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
+                    print(f"\rDownloading {filename}... Complete!")
+
+            print('Repo Downloaded')
+        except ValueError as e:
+            raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
+
+
 def create_model_menus():
     with gr.Row():
         with gr.Column():
@@ -182,6 +240,15 @@ def create_model_menus():
             with gr.Row():
                 shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
                 ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
+    with gr.Row():
+        with gr.Column(scale=0.5):
+            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
+                                                            info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
+    with gr.Row():
+        with gr.Column(scale=0.5):
+            shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
+            shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
+                                                   show_progress=True)
 
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)