Prechádzať zdrojové kódy

Silero TTS offline cache (#628)

Φφ 2 rokov pred
rodič
commit
e563b015d8
1 zmenil súbory, kde vykonal 14 pridanie a 4 odobranie
  1. 14 4
      extensions/silero_tts/script.py

+ 14 - 4
extensions/silero_tts/script.py

@@ -21,6 +21,7 @@ params = {
     'autoplay': True,
     'voice_pitch': 'medium',
     'voice_speed': 'medium',
+    'local_cache_path': ''  # User can override the default cache path to something other via settings.json
 }
 
 current_params = params.copy()
@@ -44,14 +45,18 @@ def xmlesc(txt):
 
 
 def load_model():
-    model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
+    torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
+    model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
+    if Path(model_path).is_file():
+        print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
+        model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
+    else:
+        print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
+        model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
     model.to(params['device'])
     return model
 
 
-model = load_model()
-
-
 def remove_tts_from_history(name1, name2, mode):
     for i, entry in enumerate(shared.history['internal']):
         shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
@@ -132,6 +137,11 @@ def bot_prefix_modifier(string):
     return string
 
 
+def setup():
+    global model
+    model = load_model()
+
+
 def ui():
     # Gradio elements
     with gr.Accordion("Silero TTS"):