script.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import time
  2. from pathlib import Path
  3. import gradio as gr
  4. import torch
  5. from extensions.silero_tts import tts_preprocessor
  6. from modules import chat, shared
  7. from modules.html_generator import chat_html_wrapper
  8. torch._C._jit_set_profiling_mode(False)
  9. params = {
  10. 'activate': True,
  11. 'speaker': 'en_56',
  12. 'language': 'en',
  13. 'model_id': 'v3_en',
  14. 'sample_rate': 48000,
  15. 'device': 'cpu',
  16. 'show_text': False,
  17. 'autoplay': True,
  18. 'voice_pitch': 'medium',
  19. 'voice_speed': 'medium',
  20. 'local_cache_path': '' # User can override the default cache path to something other via settings.json
  21. }
  22. current_params = params.copy()
  23. voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
  24. voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
  25. voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
  26. streaming_state = shared.args.no_stream # remember if chat streaming was enabled
  27. # Used for making text xml compatible, needed for voice pitch and speed control
  28. table = str.maketrans({
  29. "<": "&lt;",
  30. ">": "&gt;",
  31. "&": "&amp;",
  32. "'": "&apos;",
  33. '"': "&quot;",
  34. })
  35. def xmlesc(txt):
  36. return txt.translate(table)
  37. def load_model():
  38. torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
  39. model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
  40. if Path(model_path).is_file():
  41. print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
  42. model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
  43. else:
  44. print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
  45. model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
  46. model.to(params['device'])
  47. return model
  48. def remove_tts_from_history(name1, name2, mode):
  49. for i, entry in enumerate(shared.history['internal']):
  50. shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
  51. return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
  52. def toggle_text_in_history(name1, name2, mode):
  53. for i, entry in enumerate(shared.history['visible']):
  54. visible_reply = entry[1]
  55. if visible_reply.startswith('<audio'):
  56. if params['show_text']:
  57. reply = shared.history['internal'][i][1]
  58. shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
  59. else:
  60. shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
  61. return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
  62. def input_modifier(string):
  63. """
  64. This function is applied to your text inputs before
  65. they are fed into the model.
  66. """
  67. # Remove autoplay from the last reply
  68. if shared.is_chat() and len(shared.history['internal']) > 0:
  69. shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
  70. shared.processing_message = "*Is recording a voice message...*"
  71. shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
  72. return string
  73. def output_modifier(string):
  74. """
  75. This function is applied to the model outputs.
  76. """
  77. global model, current_params, streaming_state
  78. for i in params:
  79. if params[i] != current_params[i]:
  80. model = load_model()
  81. current_params = params.copy()
  82. break
  83. if not params['activate']:
  84. return string
  85. original_string = string
  86. string = tts_preprocessor.preprocess(string)
  87. if string == '':
  88. string = '*Empty reply, try regenerating*'
  89. else:
  90. output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
  91. prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
  92. silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
  93. model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
  94. autoplay = 'autoplay' if params['autoplay'] else ''
  95. string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
  96. if params['show_text']:
  97. string += f'\n\n{original_string}'
  98. shared.processing_message = "*Is typing...*"
  99. shared.args.no_stream = streaming_state # restore the streaming option to the previous value
  100. return string
  101. def bot_prefix_modifier(string):
  102. """
  103. This function is only applied in chat mode. It modifies
  104. the prefix text for the Bot and can be used to bias its
  105. behavior.
  106. """
  107. return string
  108. def setup():
  109. global model
  110. model = load_model()
  111. def ui():
  112. # Gradio elements
  113. with gr.Accordion("Silero TTS"):
  114. with gr.Row():
  115. activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
  116. autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
  117. show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
  118. voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
  119. with gr.Row():
  120. v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
  121. v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
  122. with gr.Row():
  123. convert = gr.Button('Permanently replace audios with the message texts')
  124. convert_cancel = gr.Button('Cancel', visible=False)
  125. convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
  126. # Convert history with confirmation
  127. convert_arr = [convert_confirm, convert, convert_cancel]
  128. convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
  129. convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
  130. convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
  131. convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
  132. convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
  133. # Toggle message text in history
  134. show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
  135. show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
  136. show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
  137. # Event functions to update the parameters in the backend
  138. activate.change(lambda x: params.update({"activate": x}), activate, None)
  139. autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
  140. voice.change(lambda x: params.update({"speaker": x}), voice, None)
  141. v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
  142. v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)