|
@@ -1,18 +1,19 @@
|
|
|
import re
|
|
import re
|
|
|
|
|
+import gc
|
|
|
import time
|
|
import time
|
|
|
import glob
|
|
import glob
|
|
|
-from sys import exit
|
|
|
|
|
import torch
|
|
import torch
|
|
|
import argparse
|
|
import argparse
|
|
|
import json
|
|
import json
|
|
|
|
|
+from sys import exit
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
import gradio as gr
|
|
import gradio as gr
|
|
|
-import transformers
|
|
|
|
|
-from html_generator import *
|
|
|
|
|
-from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
import warnings
|
|
import warnings
|
|
|
-import gc
|
|
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
+import transformers
|
|
|
|
|
+from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
+from modules.html_generator import *
|
|
|
|
|
+from modules.ui import *
|
|
|
|
|
|
|
|
|
|
|
|
|
transformers.logging.set_verbosity_error()
|
|
transformers.logging.set_verbosity_error()
|
|
@@ -36,9 +37,18 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
loaded_preset = None
|
|
loaded_preset = None
|
|
|
-available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
|
|
|
|
|
-available_presets = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
|
|
|
|
-available_characters = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
|
|
|
|
|
|
|
+def get_available_models():
|
|
|
|
|
+ return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
|
|
|
|
|
+
|
|
|
|
|
+def get_available_presets():
|
|
|
|
|
+ return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
|
|
|
|
+
|
|
|
|
|
+def get_available_characters():
|
|
|
|
|
+ return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
|
|
|
|
|
+
|
|
|
|
|
+available_models = get_available_models()
|
|
|
|
|
+available_presets = get_available_presets()
|
|
|
|
|
+available_characters = get_available_characters()
|
|
|
|
|
|
|
|
settings = {
|
|
settings = {
|
|
|
'max_new_tokens': 200,
|
|
'max_new_tokens': 200,
|
|
@@ -227,7 +237,7 @@ else:
|
|
|
default_text = settings['prompt']
|
|
default_text = settings['prompt']
|
|
|
|
|
|
|
|
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
|
|
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
|
|
|
-css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
|
|
|
|
|
|
|
+css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
|
|
|
if args.chat or args.cai_chat:
|
|
if args.chat or args.cai_chat:
|
|
|
history = []
|
|
history = []
|
|
|
character = None
|
|
character = None
|
|
@@ -413,24 +423,30 @@ if args.chat or args.cai_chat:
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
|
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
|
|
- model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
|
|
|
|
|
+ create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
|
|
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
|
|
|
- preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
|
|
|
|
|
+ create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
|
|
|
|
|
|
|
|
name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
|
|
name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
|
|
|
name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
|
name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
|
|
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
|
|
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
- character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character')
|
|
|
|
|
|
|
+ character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
|
|
|
|
|
+ create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
|
|
|
|
|
+
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
|
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
- gr.Markdown("Upload chat history")
|
|
|
|
|
|
|
+ gr.Markdown("Upload chat history", elem_id="upload-label")
|
|
|
upload = gr.File(type='binary')
|
|
upload = gr.File(type='binary')
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
- gr.Markdown("Download chat history")
|
|
|
|
|
|
|
+ gr.Markdown("Download chat history", elem_id="download-label")
|
|
|
save_btn = gr.Button(value="Click me")
|
|
save_btn = gr.Button(value="Click me")
|
|
|
download = gr.File()
|
|
download = gr.File()
|
|
|
|
|
|
|
@@ -473,9 +489,13 @@ elif args.notebook:
|
|
|
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
|
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
- model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
|
|
|
|
|
+ create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
- preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
|
|
|
|
|
+ create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
|
|
|
|
|
|
|
|
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
|
|
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
|
|
|
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
|
|
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
|
|
@@ -488,8 +508,12 @@ else:
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
textbox = gr.Textbox(value=default_text, lines=15, label='Input')
|
|
textbox = gr.Textbox(value=default_text, lines=15, label='Input')
|
|
|
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
|
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
|
|
- preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
|
|
|
|
|
- model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
|
|
|
|
|
+ create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
|
|
|
|
|
+ create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
|
|
|
btn = gr.Button("Generate")
|
|
btn = gr.Button("Generate")
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
with gr.Column():
|