Improve Silero's Preprocessor to Handle Numbers and Abbreviations Better
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
ipython
|
||||
num2words
|
||||
omegaconf
|
||||
pydub
|
||||
PyYAML
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -7,6 +6,8 @@ import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
import torch
|
||||
|
||||
from extensions.silero_tts import tts_preprocessor
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
params = {
|
||||
@@ -46,11 +47,6 @@ def load_model():
|
||||
return model
|
||||
model = load_model()
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
def remove_tts_from_history(name1, name2):
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||
@@ -98,11 +94,7 @@ def output_modifier(string):
|
||||
return string
|
||||
|
||||
original_string = string
|
||||
string = remove_surrounded_chars(string)
|
||||
string = string.replace('"', '')
|
||||
string = string.replace('“', '')
|
||||
string = string.replace('\n', ' ')
|
||||
string = string.strip()
|
||||
string = tts_preprocessor.preprocess(string)
|
||||
|
||||
if string == '':
|
||||
string = '*Empty reply, try regenerating*'
|
||||
|
||||
115
extensions/silero_tts/tts_preprocessor.py
Normal file
115
extensions/silero_tts/tts_preprocessor.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import re
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
alphabet_map = {
|
||||
"A": " Ei ",
|
||||
"B": " Bee ",
|
||||
"C": " See ",
|
||||
"D": " Dee ",
|
||||
"E": " II ",
|
||||
"F": " Eff ",
|
||||
"G": " Jee ",
|
||||
"H": " Eich ",
|
||||
"I": " Eye ",
|
||||
"J": " Jay ",
|
||||
"K": " Kay ",
|
||||
"L": " El ",
|
||||
"M": " Emm ",
|
||||
"N": " Enn ",
|
||||
"O": " Ohh ",
|
||||
"P": " Pii ",
|
||||
"Q": " Queue ",
|
||||
"R": " Are ",
|
||||
"S": " Ess ",
|
||||
"T": " Tee ",
|
||||
"U": " You ",
|
||||
"V": " Vii ",
|
||||
"W": " Double You ",
|
||||
"X": " Ex ",
|
||||
"Y": " Why ",
|
||||
"Z": "Zed" # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
|
||||
}
|
||||
|
||||
|
||||
def preprocess(string):
|
||||
string = remove_surrounded_chars(string)
|
||||
string = string.replace('"', '')
|
||||
string = string.replace('“', '')
|
||||
string = string.replace('\n', ' ')
|
||||
string = remove_commas(string)
|
||||
string = hyphen_range_to(string)
|
||||
string = num_to_words(string)
|
||||
string = string.strip()
|
||||
# TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
|
||||
# try to say the abbreviation or spell it out as I've done below is not agreed upon
|
||||
|
||||
# For now, expand abbreviations to pronunciations
|
||||
string = replace_abbreviations(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def replace_abbreviations(string):
|
||||
pattern = re.compile(r'[\s("\'\[<][A-Z]{2,4}[\s,.?!)"\'\]>]')
|
||||
result = string
|
||||
while True:
|
||||
match = pattern.search(result)
|
||||
if match is None:
|
||||
break
|
||||
|
||||
start = match.start()
|
||||
end = match.end()
|
||||
result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def replace_abbreviation(string):
|
||||
result = ""
|
||||
for char in string:
|
||||
result = match_mapping(char, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def match_mapping(char, result):
|
||||
for mapping in alphabet_map.keys():
|
||||
if char == mapping:
|
||||
return result + alphabet_map[char]
|
||||
|
||||
return result + char
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub(r'\*[^*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def hyphen_range_to(text):
|
||||
pattern = re.compile(r'(\d+)[-–](\d+)')
|
||||
result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
|
||||
return result
|
||||
|
||||
|
||||
def num_to_words(text):
|
||||
pattern = re.compile(r'\d+')
|
||||
result = pattern.sub(lambda x: num2words(int(x.group())), text)
|
||||
return result
|
||||
|
||||
|
||||
def remove_commas(text):
|
||||
import re
|
||||
pattern = re.compile(r'(\d),(\d)')
|
||||
result = pattern.sub(r'\1\2', text)
|
||||
return result
|
||||
|
||||
|
||||
def __main__(args):
|
||||
print(preprocess(args[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
__main__(sys.argv)
|
||||
Reference in New Issue
Block a user