Add back the spaces

This commit is contained in:
oobabooga
2023-04-06 20:45:41 -03:00
parent 7e31bc485c
commit 773e1246da
3 changed files with 20 additions and 0 deletions

View File

@@ -6,8 +6,10 @@ import torch
from extensions.silero_tts import tts_preprocessor from extensions.silero_tts import tts_preprocessor
from modules import chat, shared from modules import chat, shared
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
params = { params = {
'activate': True, 'activate': True,
'speaker': 'en_56', 'speaker': 'en_56',
@@ -36,20 +38,24 @@ table = str.maketrans({
'"': """, '"': """,
}) })
def xmlesc(txt): def xmlesc(txt):
return txt.translate(table) return txt.translate(table)
def load_model(): def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device']) model.to(params['device'])
return model return model
model = load_model() model = load_model()
def remove_tts_from_history(name1, name2): def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character) return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def toggle_text_in_history(name1, name2): def toggle_text_in_history(name1, name2):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
@@ -61,6 +67,7 @@ def toggle_text_in_history(name1, name2):
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"] shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character) return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -75,6 +82,7 @@ def input_modifier(string):
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -111,6 +119,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state # restore the streaming option to the previous value shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -120,22 +129,26 @@ def bot_prefix_modifier(string):
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements
with gr.Accordion("Silero TTS"): with gr.Accordion("Silero TTS"):
with gr.Row(): with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS') activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically') autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row(): with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch') v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed') v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row(): with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts') convert = gr.Button('Permanently replace audios with the message texts')
convert_cancel = gr.Button('Cancel', visible=False) convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)

View File

@@ -5,8 +5,10 @@ import torch
import tts_preprocessor import tts_preprocessor
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
params = { params = {
'activate': True, 'activate': True,
'speaker': 'en_49', 'speaker': 'en_49',
@@ -34,15 +36,18 @@ table = str.maketrans({
'"': "&quot;", '"': "&quot;",
}) })
def xmlesc(txt): def xmlesc(txt):
return txt.translate(table) return txt.translate(table)
def load_model(): def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device']) model.to(params['device'])
return model return model
model = load_model() model = load_model()
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -70,6 +75,7 @@ def output_modifier(string):
print(string) print(string)
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
output_modifier(sys.argv[1]) output_modifier(sys.argv[1])

View File

@@ -2,6 +2,7 @@ import re
from num2words import num2words from num2words import num2words
alphabet_map = { alphabet_map = {
"A": " Ei ", "A": " Ei ",
"B": " Bee ", "B": " Bee ",