script.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import asyncio
  2. import io
  3. import json
  4. import os
  5. from pathlib import Path
  6. import gradio as gr
  7. import requests
  8. import torch
  9. from elevenlabslib import *
  10. from elevenlabslib.helpers import *
  # Extension settings, mutated by the Gradio callbacks registered in ui().
  params = {
      'activate': True,
      'api_key': '12345',  # presumably a placeholder; the real key is typed into the UI
      'selected_voice': 'None',
  }
  # Choices shown in the voice dropdown before the user connects.
  initial_voice = ['None']
  # Index used to name the next audio file written by output_modifier().
  wav_idx = 0
  # ElevenLabs client. It is built by check_valid_api() when the user clicks
  # "Connect", not at import time: the default key above is a placeholder, and
  # constructing a client with it at import would be a useless (and possibly
  # failing) side effect. Every consumer checks user_info before touching it.
  user = None
  # Subscription data from the last successful API check; None = not connected.
  user_info = None
  def check_valid_api():
      """Check if the API is valid and refresh the UI accordingly.

      Rebuilds the global ``user`` client from ``params['api_key']`` and stores
      its subscription data in the global ``user_info``.

      Returns:
          A Gradio update setting the connection-status textbox to
          'Connected' or 'Disconnected'.
      """
      global user, user_info

      print('checking api')
      try:
          user = ElevenLabsUser(params['api_key'])
          user_info = user._get_subscription_data()
      except Exception as err:
          # UI-callback boundary: a bad key or a network error should show up as
          # 'Disconnected' instead of a traceback. NOTE(review): the exact
          # exception types elevenlabslib raises are not visible from this file.
          print(f'Could not validate API key: {err}')
          user_info = None

      if not params['activate']:
          return gr.update(value='Disconnected')
      if user_info is None:
          print('Incorrect API Key')
          return gr.update(value='Disconnected')
      print('Got an API Key!')
      return gr.update(value='Connected')
  def refresh_voices():
      """Once the API is verified, get the available voices and update the dropdown list.

      Returns:
          A Dropdown update whose choices are the placeholder 'None' followed by
          the ``initialName`` of every voice the connected account can use. When
          no verified connection exists (``user_info`` is None), returns an empty
          update so the dropdown is left unchanged rather than cleared.
      """
      if user_info is None:
          return gr.Dropdown.update()
      # 'None' (a string) matches initial_voice and the default
      # params['selected_voice'], keeping the placeholder selectable.
      choices = ['None'] + [voice.initialName for voice in user.get_available_voices()]
      return gr.Dropdown.update(choices=choices)
  def remove_surrounded_chars(string, marker='*'):
      """Drop every span enclosed by ``marker`` (e.g. ``*waves*``), markers included.

      Markers toggle an "inside" state, so an unmatched trailing marker drops
      everything after it, exactly like a character-by-character scan.

      Args:
          string: Text to clean.
          marker: Non-empty delimiter; defaults to '*' (roleplay actions).

      Returns:
          The text with the marked spans removed.
      """
      # Splitting on the marker alternates outside/inside segments, starting
      # with an outside one; keeping the even-indexed pieces keeps only text
      # outside markers, in one linear pass instead of repeated += concatenation.
      return ''.join(string.split(marker)[::2])
  def input_modifier(string):
      """Hook applied to the user's text before it is sent to the model.

      This extension leaves input alone, so the text is passed through as-is.
      """
      passthrough = string
      return passthrough
  def output_modifier(string):
      """
      This function is applied to the model outputs.

      When TTS is active, an API connection has been verified and the selected
      voice exists, the reply is cleaned up, synthesized with that voice and
      saved to ``extensions/elevenlabs_tts/outputs/<wav_idx>.wav``. The reply is
      then replaced by an HTML ``<audio>`` element pointing at the file.
      Otherwise the text is returned unchanged.
      """
      global wav_idx

      if not params['activate'] or user_info is None:
          return string

      voices = user.get_voices_by_name(params['selected_voice'])
      if not voices:
          # e.g. the default placeholder 'None' is still selected; indexing [0]
          # here used to raise IndexError and break generation.
          return string

      text = remove_surrounded_chars(string)
      # Strip straight and both curly double quotes, flatten newlines.
      text = text.translate(str.maketrans({'"': None, '“': None, '”': None, '\n': ' '}))
      text = text.strip()
      if text == '':
          text = 'empty reply, try regenerating'

      output_file = Path('extensions/elevenlabs_tts/outputs/{}.wav'.format(wav_idx))
      output_file.parent.mkdir(parents=True, exist_ok=True)
      audio_data = voices[0].generate_audio_bytes(text)
      # NOTE(review): the file gets a .wav name; confirm which audio format
      # generate_audio_bytes actually returns.
      save_bytes_to_path(output_file.as_posix(), audio_data)
      wav_idx += 1
      return f'<audio src="file/{output_file.as_posix()}" controls></audio>'
  def ui():
      """Build this extension's Gradio controls and wire them to ``params``.

      The checkbox, voice dropdown and API-key textbox write straight into
      ``params``. The Connect button validates the key and then refreshes the
      voice list.
      """
      # Gradio elements
      with gr.Row():
          activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
          connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
      voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
      with gr.Row():
          api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
          connect = gr.Button(value='Connect')

      def connect_and_refresh():
          # refresh_voices() reads the user_info global that check_valid_api()
          # sets. Two separate listeners on the same click are not guaranteed to
          # run in sequence, so run both here, in order, in one callback.
          status = check_valid_api()
          return status, refresh_voices()

      # Event functions to update the parameters in the backend
      activate.change(lambda x: params.update({'activate': x}), activate, None)
      voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
      api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
      connect.click(connect_and_refresh, [], [connection_status, voice])