浏览代码

Add refresh buttons for the model/preset/character menus

oobabooga 3 年之前
父节点
当前提交
434d4b128c
共有 5 个文件被更改,包括 73 次插入18 次删除
  1. 1 0
      README.md
  2. 0 0
      modules/html_generator.py
  3. 30 0
      modules/ui.py
  4. 42 18
      server.py
  5. 0 0
      torch-dumps/place-your-pt-models-here.txt

+ 1 - 0
README.md

@@ -150,3 +150,4 @@ Pull requests, suggestions, and issue reports are welcome.
 - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
 - Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
 - Verbose preset: Anonymous 4chan user.
+- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui

+ 0 - 0
html_generator.py → modules/html_generator.py


+ 30 - 0
modules/ui.py

@@ -0,0 +1,30 @@
+import gradio as gr
+
+refresh_symbol = '\U0001f504'  # 🔄
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+    """Small button with single emoji as text, fits inside gradio forms"""
+
+    def __init__(self, **kwargs):
+        super().__init__(variant="tool", **kwargs)
+
+    def get_block_name(self):
+        return "button"
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    def refresh():
+        refresh_method()
+        args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+        for k, v in args.items():
+            setattr(refresh_component, k, v)
+
+        return gr.update(**(args or {}))
+
+    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+    refresh_button.click(
+        fn=refresh,
+        inputs=[],
+        outputs=[refresh_component]
+    )
+    return refresh_button

+ 42 - 18
server.py

@@ -1,18 +1,19 @@
 import re
+import gc
 import time
 import glob
-from sys import exit
 import torch
 import argparse
 import json
+from sys import exit
 from pathlib import Path
 import gradio as gr
-import transformers
-from html_generator import *
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import warnings
-import gc
 from tqdm import tqdm
+import transformers
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from modules.html_generator import *
+from modules.ui import *
 
 
 transformers.logging.set_verbosity_error()
@@ -36,9 +37,18 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
 args = parser.parse_args()
 
 loaded_preset = None
-available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
-available_presets = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
-available_characters = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
+def get_available_models():
+    return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
+
+def get_available_presets():
+    return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
+
+def get_available_characters():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
+
+available_models = get_available_models()
+available_presets = get_available_presets()
+available_characters = get_available_characters()
 
 settings = {
     'max_new_tokens': 200,
@@ -227,7 +237,7 @@ else:
     default_text = settings['prompt']
 
 description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
-css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
+css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
 if args.chat or args.cai_chat:
     history = []
     character = None
@@ -413,24 +423,30 @@ if args.chat or args.cai_chat:
         with gr.Row():
             with gr.Column():
                 length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
-                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+                with gr.Row():
+                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+                    create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
             with gr.Column():
                 history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
-                preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
+                with gr.Row():
+                    preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
+                    create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
 
         name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
         name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
         context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
         with gr.Row():
-            character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character')
+            character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
+            create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
+
         with gr.Row():
             check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
         with gr.Row():
             with gr.Column():
-                gr.Markdown("Upload chat history")
+                gr.Markdown("Upload chat history", elem_id="upload-label")
                 upload = gr.File(type='binary')
             with gr.Column():
-                gr.Markdown("Download chat history")
+                gr.Markdown("Download chat history", elem_id="download-label")
                 save_btn = gr.Button(value="Click me")
                 download = gr.File()
 
@@ -473,9 +489,13 @@ elif args.notebook:
         length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
         with gr.Row():
             with gr.Column():
-                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+                with gr.Row():
+                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+                    create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
             with gr.Column():
-                preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
+                with gr.Row():
+                    preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
+                    create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
 
         gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
         gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
@@ -488,8 +508,12 @@ else:
             with gr.Column():
                 textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                 length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
-                preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
-                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+                with gr.Row():
+                    preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
+                    create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
+                with gr.Row():
+                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
+                    create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
                 btn = gr.Button("Generate")
                 with gr.Row():
                     with gr.Column():

+ 0 - 0
torch-dumps/place-your-pt-models-here.txt