script.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from pathlib import Path
  2. import gradio as gr
  3. import torch
  4. import modules.shared as shared
  5. torch._C._jit_set_profiling_mode(False)
  6. params = {
  7. 'activate': True,
  8. 'speaker': 'en_5',
  9. 'language': 'en',
  10. 'model_id': 'v3_en',
  11. 'sample_rate': 48000,
  12. 'device': 'cpu',
  13. 'show_text': False,
  14. 'autoplay': True,
  15. 'voice_pitch': 'medium',
  16. 'voice_speed': 'medium',
  17. }
  18. current_params = params.copy()
  19. voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
  20. voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
  21. voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
  22. last_msg_id = 0
  23. # Used for making text xml compatible, needed for voice pitch and speed control
  24. table = str.maketrans({
  25. "<": "&lt;",
  26. ">": "&gt;",
  27. "&": "&amp;",
  28. "'": "&apos;",
  29. '"': "&quot;",
  30. })
  31. def xmlesc(txt):
  32. return txt.translate(table)
  33. def load_model():
  34. model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
  35. model.to(params['device'])
  36. return model
  37. model = load_model()
  38. def remove_surrounded_chars(string):
  39. new_string = ""
  40. in_star = False
  41. for char in string:
  42. if char == '*':
  43. in_star = not in_star
  44. elif not in_star:
  45. new_string += char
  46. return new_string
  47. def input_modifier(string):
  48. """
  49. This function is applied to your text inputs before
  50. they are fed into the model.
  51. """
  52. # Remove autoplay from previous
  53. if len(shared.history['internal'])>0:
  54. [text, reply] = shared.history['internal'][-1]
  55. [visible_text, visible_reply] = shared.history['visible'][-1]
  56. rep_clean = reply.replace('controls autoplay>','controls>')
  57. vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
  58. shared.history['internal'][-1] = [text, rep_clean]
  59. shared.history['visible'][-1] = [visible_text, vis_rep_clean]
  60. return string
  61. def output_modifier(string):
  62. """
  63. This function is applied to the model outputs.
  64. """
  65. global model, current_params
  66. for i in params:
  67. if params[i] != current_params[i]:
  68. model = load_model()
  69. current_params = params.copy()
  70. break
  71. if params['activate'] == False:
  72. return string
  73. orig_string = string
  74. string = remove_surrounded_chars(string)
  75. string = string.replace('"', '')
  76. string = string.replace('“', '')
  77. string = string.replace('\n', ' ')
  78. string = string.strip()
  79. silent_string = False # Used to prevent unnecessary audio file generation
  80. if string == '':
  81. string = 'empty reply, try regenerating'
  82. silent_string = True
  83. # x-slow, slow, medium, fast, x-fast
  84. # x-low, low, medium, high, x-high
  85. pitch = params['voice_pitch']
  86. speed = params['voice_speed']
  87. prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
  88. string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
  89. current_msg_id = len(shared.history['visible']) # Check length here, since output_modifier can run many times on the same message
  90. output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{current_msg_id:06d}.wav')
  91. if not shared.still_streaming and not silent_string:
  92. model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
  93. string = f'<audio id="audio_{current_msg_id:06d}" src="file/{output_file.as_posix()}" controls autoplay></audio>\n\n'
  94. else:
  95. # Placeholder so text doesn't shift around so much
  96. string = '<audio controls></audio>\n\n'
  97. if params['show_text']:
  98. #string += f'*[{current_msg_id}]:*'+orig_string #Debug, looks like there is a delay in "current_msg_id" being updated when switching characters (updates after new message sent). Can't find the source. "shared.character" is updating properly.
  99. string += orig_string
  100. return string
  101. def bot_prefix_modifier(string):
  102. """
  103. This function is only applied in chat mode. It modifies
  104. the prefix text for the Bot and can be used to bias its
  105. behavior.
  106. """
  107. return string
  108. def ui():
  109. # Gradio elements
  110. with gr.Accordion("Silero TTS"):
  111. activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
  112. show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
  113. autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
  114. voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
  115. v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
  116. v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
  117. # Event functions to update the parameters in the backend
  118. activate.change(lambda x: params.update({"activate": x}), activate, None)
  119. show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
  120. autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
  121. voice.change(lambda x: params.update({"speaker": x}), voice, None)
  122. v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
  123. v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)