Reformat everything
This commit is contained in:
@@ -9,6 +9,7 @@ params = {
|
||||
'port': 5000,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
prompt_lines = [l.strip() for l in prompt.split('\n')]
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
|
||||
@@ -40,18 +41,18 @@ class Handler(BaseHTTPRequestHandler):
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
|
||||
'num_beams': int(body.get('num_beams',1)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
@@ -59,7 +60,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
}
|
||||
|
||||
generator = generate_reply(
|
||||
prompt,
|
||||
prompt,
|
||||
generate_params,
|
||||
stopping_strings=body.get('stopping_strings', []),
|
||||
)
|
||||
@@ -84,9 +85,9 @@ class Handler(BaseHTTPRequestHandler):
|
||||
def run_server():
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
if shared.args.share:
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting KoboldAI compatible api at {public_url}/api')
|
||||
except ImportError:
|
||||
@@ -95,5 +96,6 @@ def run_server():
|
||||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def setup():
|
||||
Thread(target=run_server, daemon=True).start()
|
||||
|
||||
@@ -5,14 +5,16 @@ params = {
|
||||
"bias string": " *I am so happy*",
|
||||
}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
"""
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -20,6 +22,7 @@ def output_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
|
||||
behavior.
|
||||
"""
|
||||
|
||||
if params['activate'] == True:
|
||||
if params['activate']:
|
||||
return f'{string} {params["bias string"].strip()} '
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||
|
||||
@@ -20,16 +20,18 @@ user_info = None
|
||||
if not shared.args.no_stream:
|
||||
print("Please add --no-stream. This extension is not meant to be used with streaming.")
|
||||
raise ValueError
|
||||
|
||||
|
||||
# Check if the API is valid and refresh the UI accordingly.
|
||||
|
||||
|
||||
def check_valid_api():
|
||||
|
||||
|
||||
global user, user_info, params
|
||||
|
||||
user = ElevenLabsUser(params['api_key'])
|
||||
user_info = user._get_subscription_data()
|
||||
print('checking api')
|
||||
if params['activate'] == False:
|
||||
if not params['activate']:
|
||||
return gr.update(value='Disconnected')
|
||||
elif user_info is None:
|
||||
print('Incorrect API Key')
|
||||
@@ -37,24 +39,28 @@ def check_valid_api():
|
||||
else:
|
||||
print('Got an API Key!')
|
||||
return gr.update(value='Connected')
|
||||
|
||||
|
||||
# Once the API is verified, get the available voices and update the dropdown list
|
||||
|
||||
|
||||
def refresh_voices():
|
||||
|
||||
|
||||
global user, user_info
|
||||
|
||||
|
||||
your_voices = [None]
|
||||
if user_info is not None:
|
||||
for voice in user.get_available_voices():
|
||||
your_voices.append(voice.initialName)
|
||||
return gr.Dropdown.update(choices=your_voices)
|
||||
return gr.Dropdown.update(choices=your_voices)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
@@ -64,16 +70,17 @@ def input_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
global params, wav_idx, user, user_info
|
||||
|
||||
if params['activate'] == False:
|
||||
|
||||
if not params['activate']:
|
||||
return string
|
||||
elif user_info == None:
|
||||
elif user_info is None:
|
||||
return string
|
||||
|
||||
string = remove_surrounded_chars(string)
|
||||
@@ -84,7 +91,7 @@ def output_modifier(string):
|
||||
|
||||
if string == '':
|
||||
string = 'empty reply, try regenerating'
|
||||
|
||||
|
||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
|
||||
voice = user.get_voices_by_name(params['selected_voice'])[0]
|
||||
audio_data = voice.generate_audio_bytes(string)
|
||||
@@ -94,6 +101,7 @@ def output_modifier(string):
|
||||
wav_idx += 1
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
@@ -110,4 +118,4 @@ def ui():
|
||||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
|
||||
connect.click(check_valid_api, [], connection_status)
|
||||
connect.click(refresh_voices, [], voice)
|
||||
connect.click(refresh_voices, [], voice)
|
||||
|
||||
@@ -85,7 +85,7 @@ def select_character(evt: gr.SelectData):
|
||||
def ui():
|
||||
with gr.Accordion("Character gallery", open=False):
|
||||
update = gr.Button("Refresh")
|
||||
gr.HTML(value="<style>"+generate_css()+"</style>")
|
||||
gr.HTML(value="<style>" + generate_css() + "</style>")
|
||||
gallery = gr.Dataset(components=[gr.HTML(visible=False)],
|
||||
label="",
|
||||
samples=generate_html(),
|
||||
@@ -93,4 +93,4 @@ def ui():
|
||||
samples_per_page=50
|
||||
)
|
||||
update.click(generate_html, [], gallery)
|
||||
gallery.select(select_character, None, gradio['character_menu'])
|
||||
gallery.select(select_character, None, gradio['character_menu'])
|
||||
|
||||
@@ -7,14 +7,16 @@ params = {
|
||||
|
||||
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
"""
|
||||
|
||||
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -22,6 +24,7 @@ def output_modifier(string):
|
||||
|
||||
return GoogleTranslator(source='en', target=params['language string']).translate(string)
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Finding the language name from the language code to use as the default value
|
||||
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
||||
|
||||
@@ -4,12 +4,14 @@ import pandas as pd
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||
|
||||
|
||||
def get_prompt_by_name(name):
|
||||
if name == 'None':
|
||||
return ''
|
||||
else:
|
||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||
|
||||
|
||||
def ui():
|
||||
if not shared.is_chat():
|
||||
choices = ['None'] + list(df['Prompt name'])
|
||||
|
||||
@@ -12,30 +12,33 @@ from PIL import Image
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
# parameters which can be customized in settings.json of webui
|
||||
# parameters which can be customized in settings.json of webui
|
||||
params = {
|
||||
'enable_SD_api': False,
|
||||
'address': 'http://127.0.0.1:7860',
|
||||
'save_img': False,
|
||||
'SD_model': 'NeverEndingDream', # not really used right now
|
||||
'SD_model': 'NeverEndingDream', # not really used right now
|
||||
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
|
||||
'negative_prompt': '(worst quality, low quality:1.3)',
|
||||
'side_length': 512,
|
||||
'restore_faces': False
|
||||
}
|
||||
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
pic_id = 0
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@@ -51,7 +54,7 @@ def input_modifier(string):
|
||||
lowstr = string.lower()
|
||||
|
||||
# TODO: refactor out to separate handler and also replace detection with a regexp
|
||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||
picture_response = True
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||
shared.processing_message = "*Is sending a picture...*"
|
||||
@@ -62,6 +65,8 @@ def input_modifier(string):
|
||||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
|
||||
|
||||
def get_SD_pictures(description):
|
||||
|
||||
global params, pic_id
|
||||
@@ -77,13 +82,13 @@ def get_SD_pictures(description):
|
||||
"restore_faces": params['restore_faces'],
|
||||
"negative_prompt": params['negative_prompt']
|
||||
}
|
||||
|
||||
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
||||
r = response.json()
|
||||
|
||||
visible_result = ""
|
||||
for img_str in r['images']:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
||||
if params['save_img']:
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
||||
image.save(output_file.as_posix())
|
||||
@@ -96,11 +101,13 @@ def get_SD_pictures(description):
|
||||
image_bytes = buffered.getvalue()
|
||||
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
||||
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
||||
|
||||
|
||||
return visible_result
|
||||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
# and replace it with 'text' for the purposes of logging?
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -130,6 +137,7 @@ def output_modifier(string):
|
||||
shared.args.no_stream = streaming_state
|
||||
return image + "\n" + text
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def force_pic():
|
||||
global picture_response
|
||||
picture_response = True
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
@@ -153,7 +163,7 @@ def ui():
|
||||
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
|
||||
with gr.Column():
|
||||
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
|
||||
|
||||
|
||||
with gr.Row():
|
||||
force_btn = gr.Button("Force the next response to be a picture")
|
||||
generate_now_btn = gr.Button("Generate an image response to the input")
|
||||
@@ -162,9 +172,9 @@ def ui():
|
||||
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
|
||||
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
|
||||
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
|
||||
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
|
||||
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
||||
@@ -176,4 +186,4 @@ def ui():
|
||||
|
||||
force_btn.click(force_pic)
|
||||
generate_now_btn.click(force_pic)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
@@ -17,11 +17,13 @@ input_hijack = {
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||
|
||||
|
||||
def caption_image(raw_image):
|
||||
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
||||
out = model.generate(**inputs, max_new_tokens=100)
|
||||
return processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_chat_picture(picture, name1, name2):
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
@@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
|
||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def ui():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
|
||||
@@ -42,4 +45,4 @@ def ui():
|
||||
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear the picture from the upload field
|
||||
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
||||
picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
import torch
|
||||
from extensions.silero_tts import tts_preprocessor
|
||||
from modules import chat, shared
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
'speaker': 'en_56',
|
||||
@@ -26,7 +28,7 @@ current_params = params.copy()
|
||||
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
|
||||
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||
table = str.maketrans({
|
||||
@@ -37,26 +39,25 @@ table = str.maketrans({
|
||||
'"': """,
|
||||
})
|
||||
|
||||
|
||||
def xmlesc(txt):
|
||||
return txt.translate(table)
|
||||
|
||||
|
||||
def load_model():
|
||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
model.to(params['device'])
|
||||
return model
|
||||
model = load_model()
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
def remove_tts_from_history(name1, name2):
|
||||
def remove_tts_from_history(name1, name2, mode):
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
def toggle_text_in_history(name1, name2):
|
||||
|
||||
def toggle_text_in_history(name1, name2, mode):
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
visible_reply = entry[1]
|
||||
if visible_reply.startswith('<audio'):
|
||||
@@ -65,7 +66,8 @@ def toggle_text_in_history(name1, name2):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
else:
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
@@ -75,12 +77,13 @@ def input_modifier(string):
|
||||
|
||||
# Remove autoplay from the last reply
|
||||
if shared.is_chat() and len(shared.history['internal']) > 0:
|
||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
|
||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
|
||||
|
||||
shared.processing_message = "*Is recording a voice message...*"
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -98,11 +101,7 @@ def output_modifier(string):
|
||||
return string
|
||||
|
||||
original_string = string
|
||||
string = remove_surrounded_chars(string)
|
||||
string = string.replace('"', '')
|
||||
string = string.replace('“', '')
|
||||
string = string.replace('\n', ' ')
|
||||
string = string.strip()
|
||||
string = tts_preprocessor.preprocess(string)
|
||||
|
||||
if string == '':
|
||||
string = '*Empty reply, try regenerating*'
|
||||
@@ -118,9 +117,10 @@ def output_modifier(string):
|
||||
string += f'\n\n{original_string}'
|
||||
|
||||
shared.processing_message = "*Is typing...*"
|
||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||
return string
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -130,38 +130,42 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
with gr.Accordion("Silero TTS"):
|
||||
with gr.Row():
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
||||
|
||||
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
||||
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
||||
with gr.Row():
|
||||
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
||||
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
||||
|
||||
with gr.Row():
|
||||
convert = gr.Button('Permanently replace audios with the message texts')
|
||||
convert_cancel = gr.Button('Cancel', visible=False)
|
||||
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
|
||||
|
||||
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
||||
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
|
||||
# Toggle message text in history
|
||||
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
||||
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
||||
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
||||
voice.change(lambda x: params.update({"speaker": x}), voice, None)
|
||||
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
|
||||
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
||||
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
||||
|
||||
Reference in New Issue
Block a user