diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 4035915b..e9e84626 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -154,83 +154,6 @@ dict_language={ } -def splite_en_inf(sentence, language): - pattern = re.compile(r'[a-zA-Z. ]+') - textlist = [] - langlist = [] - pos = 0 - for match in pattern.finditer(sentence): - start, end = match.span() - if start > pos: - textlist.append(sentence[pos:start]) - langlist.append(language) - textlist.append(sentence[start:end]) - langlist.append("en") - pos = end - if pos < len(sentence): - textlist.append(sentence[pos:]) - langlist.append(language) - - return textlist, langlist - - -def clean_text_inf(text, language): - phones, word2ph, norm_text = clean_text(text, language) - phones = cleaned_text_to_sequence(phones) - - return phones, word2ph, norm_text - - -def get_bert_inf(phones, word2ph, norm_text, language): - if language == "zh": - bert = get_bert_feature(norm_text, word2ph).to(device) - else: - bert = torch.zeros( - (1024, len(phones)), - dtype=torch.float16 if is_half == True else torch.float32, - ).to(device) - - return bert - - -def nonen_clean_text_inf(text, language): - textlist, langlist = splite_en_inf(text, language) - phones_list = [] - word2ph_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) - phones_list.append(phones) - if lang=="en" or "ja": - pass - else: - word2ph_list.append(word2ph) - norm_text_list.append(norm_text) - print(word2ph_list) - phones = sum(phones_list, []) - word2ph = sum(word2ph_list, []) - norm_text = ' '.join(norm_text_list) - - return phones, word2ph, norm_text - - -def nonen_get_bert_inf(text, language): - textlist, langlist = splite_en_inf(text, language) - print(textlist) - print(langlist) - bert_list = [] - for i in range(len(textlist)): - text = textlist[i] - lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(text, lang) - bert = get_bert_inf(phones, word2ph, norm_text, lang) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - - return bert - - def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): t0 = ttime() prompt_text = prompt_text.strip("\n") @@ -260,32 +183,27 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) t1 = ttime() prompt_language = dict_language[prompt_language] text_language = dict_language[text_language] - if prompt_language == "en": - phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language) - else: - phones1, word2ph1, norm_text1 = nonen_clean_text_inf(prompt_text, prompt_language) + phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) + phones1 = cleaned_text_to_sequence(phones1) texts = text.split("\n") audio_opt = [] for text in texts: # 解决输入目标文本的空行导致报错的问题 if (len(text.strip()) == 0): continue - - if text_language == "en": - phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language) + phones2, word2ph2, norm_text2 = clean_text(text, text_language) + phones2 = cleaned_text_to_sequence(phones2) + if prompt_language == "zh": + bert1 = get_bert_feature(norm_text1, word2ph1).to(device) else: - phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language) - - if prompt_language == "en": - bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language) + bert1 = torch.zeros( + (1024, len(phones1)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + if text_language == "zh": + bert2 = get_bert_feature(norm_text2, word2ph2).to(device) else: - bert1 = nonen_get_bert_inf(prompt_text, prompt_language) - - if text_language == "en": - bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language) - else: - bert2 = nonen_get_bert_inf(text, text_language) - + bert2 = torch.zeros((1024, len(phones2))).to(bert1) bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)