diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 7626bc4..1c5dab6 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -165,6 +165,98 @@ dict_language={
 }
 
 
+def splite_en_inf(sentence, language):
+    """Split sentence into runs tagged "en" and runs tagged with language.
+
+    A run of ASCII letters, "." or spaces is tagged "en"; everything
+    between such runs is tagged with language. Returns two parallel
+    lists: (textlist, langlist).
+    """
+    pattern = re.compile(r'[a-zA-Z. ]+')
+    textlist = []
+    langlist = []
+    pos = 0
+    for match in pattern.finditer(sentence):
+        start, end = match.span()
+        if start > pos:
+            textlist.append(sentence[pos:start])
+            langlist.append(language)
+        textlist.append(sentence[start:end])
+        langlist.append("en")
+        pos = end
+    if pos < len(sentence):
+        textlist.append(sentence[pos:])
+        langlist.append(language)
+
+    return textlist, langlist
+
+
+def clean_text_inf(text, language):
+    """Run clean_text, then map phones through cleaned_text_to_sequence."""
+    phones, word2ph, norm_text = clean_text(text, language)
+    phones = cleaned_text_to_sequence(phones)
+
+    return phones, word2ph, norm_text
+
+
+def get_bert_inf(phones, word2ph, norm_text, language):
+    """Return get_bert_feature for "zh", else (1024, len(phones)) zeros."""
+    if language == "zh":
+        bert = get_bert_feature(norm_text, word2ph).to(device)
+    else:
+        bert = torch.zeros(
+            (1024, len(phones)),
+            dtype=torch.float16 if is_half == True else torch.float32,
+        ).to(device)
+
+    return bert
+
+
+def nonen_clean_text_inf(text, language):
+    """Clean text that may mix language with English, segment by segment.
+
+    Returns (phones, word2ph, norm_text): phones concatenated over all
+    segments, word2ph concatenated over the non-en/ja segments only, and
+    the segments' normalized text joined with single spaces.
+    """
+    textlist, langlist = splite_en_inf(text, language)
+    phones_list = []
+    word2ph_list = []
+    norm_text_list = []
+    for i in range(len(textlist)):
+        lang = langlist[i]
+        phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+        phones_list.append(phones)
+        # `lang=="en" or "ja"` was always truthy, so word2ph was never kept;
+        # skip only the segments that really are en/ja.
+        if lang not in ("en", "ja"):
+            word2ph_list.append(word2ph)
+        norm_text_list.append(norm_text)
+    print(word2ph_list)
+    phones = sum(phones_list, [])
+    word2ph = sum(word2ph_list, [])
+    norm_text = ' '.join(norm_text_list)
+
+    return phones, word2ph, norm_text
+
+
+def nonen_get_bert_inf(text, language):
+    """Concatenate per-segment get_bert_inf tensors along dim 1."""
+    textlist, langlist = splite_en_inf(text, language)
+    print(textlist)
+    print(langlist)
+    bert_list = []
+    for i in range(len(textlist)):
+        text = textlist[i]
+        lang = langlist[i]
+        phones, word2ph, norm_text = clean_text_inf(text, lang)
+        bert = get_bert_inf(phones, word2ph, norm_text, lang)
+        bert_list.append(bert)
+    bert = torch.cat(bert_list, dim=1)
+
+    return bert
+
+
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
@@ -194,27 +286,32 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
     t1 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
-    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
-    phones1 = cleaned_text_to_sequence(phones1)
+    if prompt_language == "en":
+        phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
+    else:
+        phones1, word2ph1, norm_text1 = nonen_clean_text_inf(prompt_text, prompt_language)
     texts = text.split("\n")
     audio_opt = []
     for text in texts:
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
             continue
-        phones2, word2ph2, norm_text2 = clean_text(text, text_language)
-        phones2 = cleaned_text_to_sequence(phones2)
-        if prompt_language == "zh":
-            bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
+
+        if text_language == "en":
+            phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
         else:
-            bert1 = torch.zeros(
-                (1024, len(phones1)),
-                dtype=torch.float16 if is_half == True else torch.float32,
-            ).to(device)
-        if text_language == "zh":
-            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
+            phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
+
+        if prompt_language == "en":
+            bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
         else:
-            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+            bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
+
+        if text_language == "en":
+            bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
+        else:
+            bert2 = nonen_get_bert_inf(text, text_language)
+
         bert = torch.cat([bert1, bert2], 1)
 
         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)