diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index fdee8d9..574b9d8 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -4,6 +4,9 @@ logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("httpcore").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("asyncio").setLevel(logging.ERROR) + +logging.getLogger("charset_normalizer").setLevel(logging.ERROR) +logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) import pdb if os.path.exists("./gweight.txt"): @@ -75,7 +78,7 @@ def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") for i in inputs: - inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model + inputs[i] = inputs[i].to(device) res = bert_model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] assert len(word2ph) == len(text) @@ -182,6 +185,83 @@ dict_language={ } +def splite_en_inf(sentence, language): + pattern = re.compile(r'[a-zA-Z. ]+') + textlist = [] + langlist = [] + pos = 0 + for match in pattern.finditer(sentence): + start, end = match.span() + if start > pos: + textlist.append(sentence[pos:start]) + langlist.append(language) + textlist.append(sentence[start:end]) + langlist.append("en") + pos = end + if pos < len(sentence): + textlist.append(sentence[pos:]) + langlist.append(language) + + return textlist, langlist + + +def clean_text_inf(text, language): + phones, word2ph, norm_text = clean_text(text, language) + phones = cleaned_text_to_sequence(phones) + + return phones, word2ph, norm_text + + +def get_bert_inf(phones, word2ph, norm_text, language): + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +def nonen_clean_text_inf(text, language): + textlist, langlist = splite_en_inf(text, language) + phones_list = [] + word2ph_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + phones_list.append(phones) + if lang == "en" or "ja": + pass + else: + word2ph_list.append(word2ph) + norm_text_list.append(norm_text) + print(word2ph_list) + phones = sum(phones_list, []) + word2ph = sum(word2ph_list, []) + norm_text = ' '.join(norm_text_list) + + return phones, word2ph, norm_text + + +def nonen_get_bert_inf(text, language): + textlist, langlist = splite_en_inf(text, language) + print(textlist) + print(langlist) + bert_list = [] + for i in range(len(textlist)): + text = textlist[i] + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(text, lang) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + + return bert + + def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): t0 = ttime() prompt_text = prompt_text.strip("\n") @@ -211,30 +291,32 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) t1 = ttime() prompt_language = dict_language[prompt_language] text_language = dict_language[text_language] - phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) - phones1 = cleaned_text_to_sequence(phones1) + + if prompt_language == "en": + phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language) + else: + phones1, word2ph1, norm_text1 = nonen_clean_text_inf(prompt_text, prompt_language) texts = text.split("\n") audio_opt = [] - - if prompt_language == "zh": - bert1 = get_bert_feature(norm_text1, word2ph1).to(device) + if prompt_language == "en": + bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language) else: - bert1 = torch.zeros( - (1024, len(phones1)), - dtype=torch.float16 if is_half == True else torch.float32, - ).to(device) + bert1 = nonen_get_bert_inf(prompt_text, prompt_language) for text in texts: # 解决输入目标文本的空行导致报错的问题 if (len(text.strip()) == 0): continue - phones2, word2ph2, norm_text2 = clean_text(text, text_language) - phones2 = cleaned_text_to_sequence(phones2) - - if text_language == "zh": - bert2 = get_bert_feature(norm_text2, word2ph2).to(device) + if text_language == "en": + phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language) else: - bert2 = torch.zeros((1024, len(phones2))).to(bert1) + phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language) + + if text_language == "en": + bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language) + else: + bert2 = nonen_get_bert_inf(text, text_language) + bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)