From 436032214a10cf7d2f7c2107ea26b8590aabc60e Mon Sep 17 00:00:00 2001 From: KamioRinn Date: Wed, 27 Mar 2024 00:19:49 +0800 Subject: [PATCH] Add zh/jp/en mix --- api.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 94 insertions(+), 22 deletions(-) diff --git a/api.py b/api.py index 34adfbe9..e763c42e 100644 --- a/api.py +++ b/api.py @@ -111,6 +111,7 @@ sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import signal +import LangSegment from time import time as ttime import torch import librosa @@ -249,6 +250,8 @@ def change_sovits_weights(sovits_path): print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) with open("./sweight.txt", "w", encoding="utf-8") as f: f.write(sovits_path) + + def change_gpt_weights(gpt_path): global hz, max_sec, t2s_model, config hz = 50 @@ -283,6 +286,83 @@ def get_bert_feature(text, word2ph): return phone_level_feature.T +def clean_text_inf(text, language): + phones, word2ph, norm_text = clean_text(text, language) + phones = cleaned_text_to_sequence(phones) + return phones, word2ph, norm_text + + +def get_bert_inf(phones, word2ph, norm_text, language): + language=language.replace("all_","") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +def get_phones_and_bert(text,language): + if language in {"en","all_zh","all_ja"}: + language = language.replace("all_","") + if language == "en": + LangSegment.setfilters(["en"]) + formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) + else: + # 因无法区别中日文汉字,以用户输入为准 + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + phones, word2ph, norm_text = clean_text_inf(formattext, language) + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + elif language in {"zh", "ja","auto"}: + textlist=[] + langlist=[] + LangSegment.setfilters(["zh","ja","en","ko"]) + if language == "auto": + for tmp in LangSegment.getTexts(text): + if tmp["lang"] == "ko": + langlist.append("zh") + textlist.append(tmp["text"]) + else: + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegment.getTexts(text): + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = ''.join(norm_text_list) + + return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text + + n_semantic = 1024 dict_s2 = torch.load(sovits_path, map_location="cpu") hps = dict_s2["config"] @@ -342,15 +422,18 @@ def get_spepc(hps, filename): dict_language = { - "中文": "zh", + "中文": "all_zh", "英文": "en", - "日文": "ja", - "ZH": "zh", - "EN": "en", - "JA": "ja", - "zh": "zh", + "日文": "all_ja", + "中英混合": "zh", + "日英混合": "ja", + "多语种混合": "auto", #多语种启动切分识别语种 + "all_zh": "all_zh", "en": "en", - "ja": "ja" + "all_ja": "all_ja", + "zh": "zh", + "ja": "ja", + "auto": "auto", } @@ -374,25 +457,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] t1 = ttime() - prompt_language = dict_language[prompt_language] - text_language = dict_language[text_language] - phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) - phones1 = cleaned_text_to_sequence(phones1) + prompt_language = dict_language[prompt_language.lower()] + text_language = dict_language[text_language.lower()] + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language) texts = text.split("\n") audio_opt = [] for text in texts: - phones2, word2ph2, norm_text2 = clean_text(text, text_language) - phones2 = cleaned_text_to_sequence(phones2) - if (prompt_language == "zh"): - bert1 = get_bert_feature(norm_text1, word2ph1).to(device) - else: - bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to( - device) - if (text_language == "zh"): - bert2 = get_bert_feature(norm_text2, word2ph2).to(device) - else: - bert2 = torch.zeros((1024, len(phones2))).to(bert1) + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language) bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)