|
@@ -21,6 +21,7 @@ params = {
|
|
|
'autoplay': True,
|
|
'autoplay': True,
|
|
|
'voice_pitch': 'medium',
|
|
'voice_pitch': 'medium',
|
|
|
'voice_speed': 'medium',
|
|
'voice_speed': 'medium',
|
|
|
|
|
+ 'local_cache_path': '' # User can override the default cache path to something other via settings.json
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
current_params = params.copy()
|
|
current_params = params.copy()
|
|
@@ -44,14 +45,18 @@ def xmlesc(txt):
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
|
|
def load_model():
|
|
|
- model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
|
|
|
|
|
|
+ torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
|
|
|
|
|
+ model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
|
|
|
|
+ if Path(model_path).is_file():
|
|
|
|
|
+ print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
|
|
|
|
+ model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
|
|
|
|
+ model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
|
|
model.to(params['device'])
|
|
model.to(params['device'])
|
|
|
return model
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
-model = load_model()
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
def remove_tts_from_history(name1, name2, mode):
|
|
def remove_tts_from_history(name1, name2, mode):
|
|
|
for i, entry in enumerate(shared.history['internal']):
|
|
for i, entry in enumerate(shared.history['internal']):
|
|
|
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
|
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
|
@@ -132,6 +137,11 @@ def bot_prefix_modifier(string):
|
|
|
return string
|
|
return string
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def setup():
|
|
|
|
|
+ global model
|
|
|
|
|
+ model = load_model()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def ui():
|
|
def ui():
|
|
|
# Gradio elements
|
|
# Gradio elements
|
|
|
with gr.Accordion("Silero TTS"):
|
|
with gr.Accordion("Silero TTS"):
|