script.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import asyncio
  2. from pathlib import Path
  3. import gradio as gr
  4. import torch
  5. torch._C._jit_set_profiling_mode(False)
  6. params = {
  7. 'activate': True,
  8. 'speaker': 'en_56',
  9. 'language': 'en',
  10. 'model_id': 'v3_en',
  11. 'sample_rate': 48000,
  12. 'device': 'cpu',
  13. }
  14. current_params = params.copy()
  15. voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
  16. wav_idx = 0
  17. def load_model():
  18. model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
  19. model.to(params['device'])
  20. return model
  21. model = load_model()
  22. def remove_surrounded_chars(string):
  23. new_string = ""
  24. in_star = False
  25. for char in string:
  26. if char == '*':
  27. in_star = not in_star
  28. elif not in_star:
  29. new_string += char
  30. return new_string
  31. def input_modifier(string):
  32. """
  33. This function is applied to your text inputs before
  34. they are fed into the model.
  35. """
  36. return string
  37. def output_modifier(string):
  38. """
  39. This function is applied to the model outputs.
  40. """
  41. global wav_idx, model, current_params
  42. for i in params:
  43. if params[i] != current_params[i]:
  44. model = load_model()
  45. current_params = params.copy()
  46. break
  47. if params['activate'] == False:
  48. return string
  49. string = remove_surrounded_chars(string)
  50. string = string.replace('"', '')
  51. string = string.replace('“', '')
  52. string = string.replace('\n', ' ')
  53. string = string.strip()
  54. if string == '':
  55. string = 'empty reply, try regenerating'
  56. output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
  57. audio = model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
  58. string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
  59. wav_idx += 1
  60. return string
  61. def bot_prefix_modifier(string):
  62. """
  63. This function is only applied in chat mode. It modifies
  64. the prefix text for the Bot and can be used to bias its
  65. behavior.
  66. """
  67. return string
  68. def ui():
  69. # Gradio elements
  70. activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
  71. voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
  72. # Event functions to update the parameters in the backend
  73. activate.change(lambda x: params.update({"activate": x}), activate, None)
  74. voice.change(lambda x: params.update({"speaker": x}), voice, None)