script.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import asyncio
  2. from pathlib import Path
  3. import gradio as gr
  4. import torch
  5. torch._C._jit_set_profiling_mode(False)
  6. params = {
  7. 'activate': True,
  8. 'speaker': 'en_56',
  9. 'language': 'en',
  10. 'model_id': 'v3_en',
  11. 'sample_rate': 48000,
  12. 'device': 'cpu',
  13. }
  14. current_params = params.copy()
  15. wav_idx = 0
  16. def load_model():
  17. model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
  18. model.to(params['device'])
  19. return model
  20. model = load_model()
  21. def remove_surrounded_chars(string):
  22. new_string = ""
  23. in_star = False
  24. for char in string:
  25. if char == '*':
  26. in_star = not in_star
  27. elif not in_star:
  28. new_string += char
  29. return new_string
  30. def input_modifier(string):
  31. """
  32. This function is applied to your text inputs before
  33. they are fed into the model.
  34. """
  35. return string
  36. def output_modifier(string):
  37. """
  38. This function is applied to the model outputs.
  39. """
  40. global wav_idx, model, current_params
  41. for i in params:
  42. if params[i] != current_params[i]:
  43. model = load_model()
  44. current_params = params.copy()
  45. break
  46. if params['activate'] == False:
  47. return string
  48. string = remove_surrounded_chars(string)
  49. string = string.replace('"', '')
  50. string = string.replace('“', '')
  51. string = string.replace('\n', ' ')
  52. string = string.strip()
  53. if string == '':
  54. string = 'empty reply, try regenerating'
  55. output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
  56. audio = model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
  57. string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
  58. wav_idx += 1
  59. return string
  60. def bot_prefix_modifier(string):
  61. """
  62. This function is only applied in chat mode. It modifies
  63. the prefix text for the Bot and can be used to bias its
  64. behavior.
  65. """
  66. return string
  67. def ui():
  68. # Gradio elements
  69. activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
  70. voice = gr.Dropdown(value=params['speaker'], choices=[f'en_{i}' for i in range(1, 118)], label='TTS voice')
  71. # Event functions to update the parameters in the backend
  72. activate.change(lambda x: params.update({"activate": x}), activate, None)
  73. voice.change(lambda x: params.update({"speaker": x}), voice, None)