From 2310bcde5378930a1472570be0f54d766616b04b Mon Sep 17 00:00:00 2001 From: KamioRinn <63162909+KamioRinn@users.noreply.github.com> Date: Sat, 10 Aug 2024 12:28:53 +0800 Subject: [PATCH] Optimize short sentence (#1430) --- GPT_SoVITS/inference_webui.py | 7 +++++-- GPT_SoVITS/text/cleaner.py | 2 +- api.py | 5 ++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 727b9f7..878f8d8 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -299,7 +299,7 @@ def get_first(text): return text from text import chinese -def get_phones_and_bert(text,language,version): +def get_phones_and_bert(text,language,version,final=False): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: language = language.replace("all_","") if language == "en": @@ -366,6 +366,9 @@ def get_phones_and_bert(text,language,version): phones = sum(phones_list, []) norm_text = ''.join(norm_text_list) + if not final and len(phones) < 6: + return get_phones_and_bert("." + text,language,version,final=True) + return phones,bert.to(dtype),norm_text @@ -408,7 +411,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." print(i18n("实际输入的参考文本:"), prompt_text) text = text.strip("\n") - if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text + # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text print(i18n("实际输入的目标文本:"), text) zero_wav = np.zeros( diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py index 1091a34..298e4d2 100644 --- a/GPT_SoVITS/text/cleaner.py +++ b/GPT_SoVITS/text/cleaner.py @@ -45,7 +45,7 @@ def clean_text(text, language, version=None): elif language == "en": phones = language_module.g2p(norm_text) if len(phones) < 4: - phones = [','] * (4 - len(phones)) + phones + phones = [','] + phones word2ph = None else: phones = language_module.g2p(norm_text) diff --git a/api.py b/api.py index e510ab9..3b17394 100644 --- a/api.py +++ b/api.py @@ -275,7 +275,7 @@ def get_bert_inf(phones, word2ph, norm_text, language): return bert from text import chinese -def get_phones_and_bert(text,language,version): +def get_phones_and_bert(text,language,version,final=False): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: language = language.replace("all_","") if language == "en": @@ -340,6 +340,9 @@ def get_phones_and_bert(text,language,version): phones = sum(phones_list, []) norm_text = ''.join(norm_text_list) + if not final and len(phones) < 6: + return get_phones_and_bert("." + text,language,version,final=True) + return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text