script.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from pathlib import Path
  2. import gradio as gr
  3. import torch
  4. import time
  5. import re
  6. import modules.shared as shared
  7. import modules.chat as chat
  8. torch._C._jit_set_profiling_mode(False)
  9. params = {
  10. 'activate': True,
  11. 'speaker': 'en_5',
  12. 'language': 'en',
  13. 'model_id': 'v3_en',
  14. 'sample_rate': 48000,
  15. 'device': 'cpu',
  16. 'show_text': False,
  17. 'autoplay': True,
  18. 'voice_pitch': 'medium',
  19. 'voice_speed': 'medium',
  20. }
  21. current_params = params.copy()
  22. voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
  23. voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
  24. voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
  25. last_msg_id = 0
  26. # Used for making text xml compatible, needed for voice pitch and speed control
  27. table = str.maketrans({
  28. "<": "&lt;",
  29. ">": "&gt;",
  30. "&": "&amp;",
  31. "'": "&apos;",
  32. '"': "&quot;",
  33. })
  34. def xmlesc(txt):
  35. return txt.translate(table)
  36. def load_model():
  37. model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
  38. model.to(params['device'])
  39. return model
  40. model = load_model()
  41. def remove_surrounded_chars(string):
  42. new_string = ""
  43. in_star = False
  44. for char in string:
  45. if char == '*':
  46. in_star = not in_star
  47. elif not in_star:
  48. new_string += char
  49. return new_string
  50. def remove_tts_from_history():
  51. suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
  52. for i, entry in enumerate(shared.history['internal']):
  53. reply = entry[1]
  54. reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
  55. if shared.args.chat:
  56. reply = reply.replace('\n', '<br>')
  57. shared.history['visible'][i][1] = reply
  58. if shared.args.cai_chat:
  59. return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
  60. else:
  61. return shared.history['visible']
  62. def toggle_text_in_history():
  63. suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
  64. audio_str='\n\n' # The '\n\n' used after </audio>
  65. if shared.args.chat:
  66. audio_str='<br><br>'
  67. if params['show_text']==True:
  68. #for i, entry in enumerate(shared.history['internal']):
  69. for i, entry in enumerate(shared.history['visible']):
  70. vis_reply = entry[1]
  71. if vis_reply.startswith('<audio'):
  72. reply = shared.history['internal'][i][1]
  73. reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
  74. if shared.args.chat:
  75. reply = reply.replace('\n', '<br>')
  76. shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply
  77. else:
  78. for i, entry in enumerate(shared.history['visible']):
  79. vis_reply = entry[1]
  80. if vis_reply.startswith('<audio'):
  81. shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str
  82. if shared.args.cai_chat:
  83. return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
  84. else:
  85. return shared.history['visible']
  86. def input_modifier(string):
  87. """
  88. This function is applied to your text inputs before
  89. they are fed into the model.
  90. """
  91. # Remove autoplay from previous chat history
  92. if (shared.args.chat or shared.args.cai_chat)and len(shared.history['internal'])>0:
  93. [visible_text, visible_reply] = shared.history['visible'][-1]
  94. vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
  95. shared.history['visible'][-1] = [visible_text, vis_rep_clean]
  96. return string
  97. def output_modifier(string):
  98. """
  99. This function is applied to the model outputs.
  100. """
  101. global model, current_params
  102. for i in params:
  103. if params[i] != current_params[i]:
  104. model = load_model()
  105. current_params = params.copy()
  106. break
  107. if params['activate'] == False:
  108. return string
  109. orig_string = string
  110. string = remove_surrounded_chars(string)
  111. string = string.replace('"', '')
  112. string = string.replace('“', '')
  113. string = string.replace('\n', ' ')
  114. string = string.strip()
  115. silent_string = False # Used to prevent unnecessary audio file generation
  116. if string == '':
  117. string = 'empty reply, try regenerating'
  118. silent_string = True
  119. pitch = params['voice_pitch']
  120. speed = params['voice_speed']
  121. prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
  122. string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
  123. if not shared.still_streaming and not silent_string:
  124. output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
  125. model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
  126. autoplay_str = ' autoplay' if params['autoplay'] else ''
  127. string = f'<audio src="file/{output_file.as_posix()}" controls{autoplay_str}></audio>\n\n'
  128. else:
  129. # Placeholder so text doesn't shift around so much
  130. string = '<audio controls></audio>\n\n'
  131. if params['show_text']:
  132. string += orig_string
  133. return string
  134. def bot_prefix_modifier(string):
  135. """
  136. This function is only applied in chat mode. It modifies
  137. the prefix text for the Bot and can be used to bias its
  138. behavior.
  139. """
  140. return string
  141. def ui():
  # Builds the extension's Gradio controls inside a "Silero TTS" accordion and
  # wires their events to the module-level `params` dict and to the chat
  # history helpers. Returns nothing.
  # NOTE(review): the original indentation was lost in this copy, so which
  # widgets belong to which `with gr.Row()` cannot be read off these lines.
  142. # Gradio elements
  143. with gr.Accordion("Silero TTS"):
  144. with gr.Row():
  145. activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
  146. autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
  147. show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
  148. voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
  149. with gr.Row():
  150. v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
  151. v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
  152. with gr.Row():
  # Only the main button is visible at first; Confirm and Cancel start hidden.
  153. convert = gr.Button('Permanently replace chat history audio with message text')
  154. convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
  155. convert_cancel = gr.Button('Cancel', visible=False)
  156. # Convert history with confirmation
  # Each visibility lambda below returns its updates in the order of this list:
  # [confirm, main button, cancel].
  157. convert_arr = [convert_confirm, convert, convert_cancel]
  # Main button clicked: show Confirm and Cancel, hide the main button.
  158. convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
  # Confirm clicked: restore the initial button visibility, write the result of
  # remove_tts_from_history to the display component, and save the history.
  159. convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
  160. convert_confirm.click(remove_tts_from_history, [], shared.gradio['display'])
  161. convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  # Cancel clicked: restore the initial button visibility, nothing else.
  162. convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
  163. # Toggle message text in history
  # Three handlers are registered on the same checkbox: store the new value in
  # params, write the result of toggle_text_in_history to the display
  # component, and save the history.
  # NOTE(review): toggle_text_in_history reads params['show_text'], so this
  # presumably relies on the handlers running in registration order - confirm.
  164. show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
  165. show_text.change(toggle_text_in_history, [], shared.gradio['display'])
  166. show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  167. # Event functions to update the parameters in the backend
  # Each widget writes its new value into the matching params key. No handler
  # here changes 'language', 'model_id', 'device' or 'sample_rate'.
  168. activate.change(lambda x: params.update({"activate": x}), activate, None)
  169. autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
  170. voice.change(lambda x: params.update({"speaker": x}), voice, None)
  171. v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
  172. v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)