فهرست منبع

Improve Silero's Preprocessor to Handle Numbers and Abbreviations Better

da3dsoul 2 سال پیش
والد
کامیت
b2022d0869
3فایلهای تغییر یافته به همراه119 افزوده شده و 11 حذف شده
  1. 1 0
      extensions/silero_tts/requirements.txt
  2. 3 11
      extensions/silero_tts/script.py
  3. 115 0
      extensions/silero_tts/tts_preprocessor.py

+ 1 - 0
extensions/silero_tts/requirements.txt

@@ -1,4 +1,5 @@
 ipython
+num2words
 omegaconf
 pydub
 PyYAML

+ 3 - 11
extensions/silero_tts/script.py

@@ -1,4 +1,3 @@
-import re
 import time
 from pathlib import Path
 
@@ -7,6 +6,8 @@ import modules.chat as chat
 import modules.shared as shared
 import torch
 
+from extensions.silero_tts import tts_preprocessor
+
 torch._C._jit_set_profiling_mode(False)
 
 params = {
@@ -46,11 +47,6 @@ def load_model():
     return model
 model = load_model()
 
-def remove_surrounded_chars(string):
-    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
-    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
-    return re.sub('\*[^\*]*?(\*|$)','',string)
-
 def remove_tts_from_history(name1, name2):
     for i, entry in enumerate(shared.history['internal']):
         shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
@@ -98,11 +94,7 @@ def output_modifier(string):
         return string
 
     original_string = string
-    string = remove_surrounded_chars(string)
-    string = string.replace('"', '')
-    string = string.replace('“', '')
-    string = string.replace('\n', ' ')
-    string = string.strip()
+    string = tts_preprocessor.preprocess(string)
 
     if string == '':
         string = '*Empty reply, try regenerating*'

+ 115 - 0
extensions/silero_tts/tts_preprocessor.py

@@ -0,0 +1,115 @@
+import re
+from num2words import num2words
+
+
+alphabet_map = {
+    "A": " Ei ",
+    "B": " Bee ",
+    "C": " See ",
+    "D": " Dee ",
+    "E": " II ",
+    "F": " Eff ",
+    "G": " Jee ",
+    "H": " Eich ",
+    "I": " Eye ",
+    "J": " Jay ",
+    "K": " Kay ",
+    "L": " El ",
+    "M": " Emm ",
+    "N": " Enn ",
+    "O": " Ohh ",
+    "P": " Pii ",
+    "Q": " Queue ",
+    "R": " Are ",
+    "S": " Ess ",
+    "T": " Tee ",
+    "U": " You ",
+    "V": " Vii ",
+    "W": " Double You ",
+    "X": " Ex ",
+    "Y": " Why ",
+    "Z": "Zed"  # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
+}
+
+
+def preprocess(string):
+    string = remove_surrounded_chars(string)
+    string = string.replace('"', '')
+    string = string.replace('“', '')
+    string = string.replace('\n', ' ')
+    string = remove_commas(string)
+    string = hyphen_range_to(string)
+    string = num_to_words(string)
+    string = string.strip()
+    # TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
+    # try to say the abbreviation or spell it out as I've done below is not agreed upon
+
+    # For now, expand abbreviations to pronunciations
+    string = replace_abbreviations(string)
+
+    return string
+
+
+def replace_abbreviations(string):
+    pattern = re.compile(r'[\s("\'\[<][A-Z]{2,4}[\s,.?!)"\'\]>]')
+    result = string
+    while True:
+        match = pattern.search(result)
+        if match is None:
+            break
+
+        start = match.start()
+        end = match.end()
+        result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
+
+    return result
+
+
+def replace_abbreviation(string):
+    result = ""
+    for char in string:
+        result = match_mapping(char, result)
+
+    return result
+
+
+def match_mapping(char, result):
+    for mapping in alphabet_map.keys():
+        if char == mapping:
+            return result + alphabet_map[char]
+
+    return result + char
+
+
+def remove_surrounded_chars(string):
+    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+    return re.sub(r'\*[^*]*?(\*|$)', '', string)
+
+
+def hyphen_range_to(text):
+    pattern = re.compile(r'(\d+)[-–](\d+)')
+    result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
+    return result
+
+
+def num_to_words(text):
+    pattern = re.compile(r'\d+')
+    result = pattern.sub(lambda x: num2words(int(x.group())), text)
+    return result
+
+
+def remove_commas(text):
+    import re
+    pattern = re.compile(r'(\d),(\d)')
+    result = pattern.sub(r'\1\2', text)
+    return result
+
+
+def __main__(args):
+    print(preprocess(args[1]))
+
+
+if __name__ == "__main__":
+    import sys
+    __main__(sys.argv)