瀏覽代碼

Add back the spaces

oobabooga 2 年之前
父節點
當前提交
773e1246da
共有 3 個文件被更改,包括 20 次插入0 次删除
  1. 13 0
      extensions/silero_tts/script.py
  2. 6 0
      extensions/silero_tts/test_tts.py
  3. 1 0
      extensions/silero_tts/tts_preprocessor.py

+ 13 - 0
extensions/silero_tts/script.py

@@ -6,8 +6,10 @@ import torch
 from extensions.silero_tts import tts_preprocessor
 from modules import chat, shared
 
+
 torch._C._jit_set_profiling_mode(False)
 
+
 params = {
     'activate': True,
     'speaker': 'en_56',
@@ -36,20 +38,24 @@ table = str.maketrans({
     '"': """,
 })
 
+
 def xmlesc(txt):
     return txt.translate(table)
 
+
 def load_model():
     model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
     model.to(params['device'])
     return model
 model = load_model()
 
+
 def remove_tts_from_history(name1, name2):
     for i, entry in enumerate(shared.history['internal']):
         shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
     return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
+
 def toggle_text_in_history(name1, name2):
     for i, entry in enumerate(shared.history['visible']):
         visible_reply = entry[1]
@@ -61,6 +67,7 @@ def toggle_text_in_history(name1, name2):
                 shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
     return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
@@ -75,6 +82,7 @@ def input_modifier(string):
     shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
     return string
 
+
 def output_modifier(string):
     """
     This function is applied to the model outputs.
@@ -111,6 +119,7 @@ def output_modifier(string):
     shared.args.no_stream = streaming_state # restore the streaming option to the previous value
     return string
 
+
 def bot_prefix_modifier(string):
     """
     This function is only applied in chat mode. It modifies
@@ -120,22 +129,26 @@ def bot_prefix_modifier(string):
 
     return string
 
+
 def ui():
     # Gradio elements
     with gr.Accordion("Silero TTS"):
         with gr.Row():
             activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
             autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
+
         show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
         voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
         with gr.Row():
             v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
             v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
+
         with gr.Row():
             convert = gr.Button('Permanently replace audios with the message texts')
             convert_cancel = gr.Button('Cancel', visible=False)
             convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
 
+
     # Convert history with confirmation
     convert_arr = [convert_confirm, convert, convert_cancel]
     convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)

+ 6 - 0
extensions/silero_tts/test_tts.py

@@ -5,8 +5,10 @@ import torch
 
 import tts_preprocessor
 
+
 torch._C._jit_set_profiling_mode(False)
 
+
 params = {
     'activate': True,
     'speaker': 'en_49',
@@ -34,15 +36,18 @@ table = str.maketrans({
     '"': "&quot;",
 })
 
+
 def xmlesc(txt):
     return txt.translate(table)
 
+
 def load_model():
     model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
     model.to(params['device'])
     return model
 model = load_model()
 
+
 def output_modifier(string):
     """
     This function is applied to the model outputs.
@@ -70,6 +75,7 @@ def output_modifier(string):
 
     print(string)
 
+
 if __name__ == '__main__':
     import sys
     output_modifier(sys.argv[1])

+ 1 - 0
extensions/silero_tts/tts_preprocessor.py

@@ -2,6 +2,7 @@ import re
 
 from num2words import num2words
 
+
 alphabet_map = {
     "A": " Ei ",
     "B": " Bee ",