script.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from pathlib import Path
  2. import gradio as gr
  3. import torch
  4. import modules.shared as shared
  5. import simpleaudio as sa
  6. torch._C._jit_set_profiling_mode(False)
  7. params = {
  8. 'activate': True,
  9. 'speaker': 'en_5',
  10. 'language': 'en',
  11. 'model_id': 'v3_en',
  12. 'sample_rate': 48000,
  13. 'device': 'cpu',
  14. 'max_wavs': 20,
  15. 'play_audio': True,
  16. 'show_text': True,
  17. }
  18. current_params = params.copy()
  19. voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
  20. wav_idx = 0
  21. table = str.maketrans({
  22. "<": "&lt;",
  23. ">": "&gt;",
  24. "&": "&amp;",
  25. "'": "&apos;",
  26. '"': "&quot;",
  27. })
  28. def xmlesc(txt):
  29. return txt.translate(table)
  30. def load_model():
  31. model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
  32. model.to(params['device'])
  33. return model
  34. model = load_model()
  35. def remove_surrounded_chars(string):
  36. new_string = ""
  37. in_star = False
  38. for char in string:
  39. if char == '*':
  40. in_star = not in_star
  41. elif not in_star:
  42. new_string += char
  43. return new_string
  44. def input_modifier(string):
  45. """
  46. This function is applied to your text inputs before
  47. they are fed into the model.
  48. """
  49. return string
  50. def output_modifier(string):
  51. """
  52. This function is applied to the model outputs.
  53. """
  54. global wav_idx, model, current_params
  55. for i in params:
  56. if params[i] != current_params[i]:
  57. model = load_model()
  58. current_params = params.copy()
  59. break
  60. if params['activate'] == False:
  61. return string
  62. orig_string = string
  63. string = remove_surrounded_chars(string)
  64. string = string.replace('"', '')
  65. string = string.replace('“', '')
  66. string = string.replace('\n', ' ')
  67. string = string.strip()
  68. auto_playable=True
  69. if string == '':
  70. string = 'empty reply, try regenerating'
  71. auto_playable=False
  72. #x-slow, slow, medium, fast, x-fast
  73. #x-low, low, medium, high, x-high
  74. #prosody='<prosody rate="fast" pitch="medium">'
  75. prosody='<prosody rate="fast">'
  76. string ='<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
  77. output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
  78. model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
  79. string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
  80. #reset if too many wavs. set max to -1 for unlimited.
  81. if wav_idx < params['max_wavs'] or params['max_wavs'] < 0:
  82. #only increment if starting a new stream, else replace during streaming. Does not update duration on webui sometimes?
  83. if not shared.still_streaming:
  84. wav_idx += 1
  85. else:
  86. wav_idx = 0
  87. if params['show_text']:
  88. string+='\n\n'+orig_string
  89. #if params['play_audio'] == True and auto_playable and shared.stop_everything:
  90. if params['play_audio'] == True and auto_playable and not shared.still_streaming:
  91. stop_autoplay()
  92. wave_obj = sa.WaveObject.from_wave_file(output_file.as_posix())
  93. wave_obj.play()
  94. return string
  95. def bot_prefix_modifier(string):
  96. """
  97. This function is only applied in chat mode. It modifies
  98. the prefix text for the Bot and can be used to bias its
  99. behavior.
  100. """
  101. return string
  102. def stop_autoplay():
  103. sa.stop_all()
  104. def ui():
  105. # Gradio elements
  106. activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
  107. show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
  108. play_audio = gr.Checkbox(value=params['play_audio'], label='Play TTS automatically')
  109. stop_audio = gr.Button("Stop Auto-Play")
  110. voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
  111. # Event functions to update the parameters in the backend
  112. activate.change(lambda x: params.update({"activate": x}), activate, None)
  113. play_audio.change(lambda x: params.update({"play_audio": x}), play_audio, None)
  114. show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
  115. stop_audio.click(stop_autoplay)
  116. voice.change(lambda x: params.update({"speaker": x}), voice, None)