From 07c620c17ee69cd937b6649eef9c534aec38b48e Mon Sep 17 00:00:00 2001 From: zih-an Date: Fri, 1 Mar 2024 14:38:35 +0000 Subject: [PATCH] added support for mixed language API according to the current implementation in inference_webui.py --- api.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 13 deletions(-) diff --git a/api.py b/api.py index 754f0769..993f090b 100644 --- a/api.py +++ b/api.py @@ -80,6 +80,23 @@ RESP: 失败: json, 400 +### 动态更换底模 + +endpoint: `/set_model` + +GET: + `http://127.0.0.1:9880/set_model?gpt_model_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt&sovits_model_path=GPT_SoVITS/pretrained_models/s2G488k.pth` +POST: +```json +{ + "gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + "sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth" +} +``` + +RESP: +成功: json, http code 200 + ### 命令控制 endpoint: `/control` @@ -126,6 +143,7 @@ from module.models import SynthesizerTrn from AR.models.t2s_lightning_module import Text2SemanticLightningModule from text import cleaned_text_to_sequence from text.cleaner import clean_text +import LangSegment from module.mel_processing import spectrogram_torch from my_utils import load_audio import config as global_config @@ -350,9 +368,84 @@ dict_language = { "JA": "ja", "zh": "zh", "en": "en", - "ja": "ja" + "ja": "ja", + "auto": "auto", + "中英混合": "zh", + "日英混合": "ja", + "多语种混合": "auto" } +dtype=torch.float16 if is_half == True else torch.float32 +def get_bert_inf(phones, word2ph, norm_text, language): + language=language.replace("all_","") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + +def clean_text_inf(text, language): + phones, word2ph, norm_text = clean_text(text, language) + phones = cleaned_text_to_sequence(phones) + return phones, word2ph, norm_text + +# Mixed language support +def get_phones_and_bert(text,language): + if language in {"en","all_zh","all_ja"}: + language = language.replace("all_","") + if language == "en": + LangSegment.setfilters(["en"]) + formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) + else: + # 因无法区别中日文汉字,以用户输入为准 + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + phones, word2ph, norm_text = clean_text_inf(formattext, language) + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + elif language in {"zh", "ja","auto"}: + textlist=[] + langlist=[] + LangSegment.setfilters(["zh","ja","en"]) + if language == "auto": + for tmp in LangSegment.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegment.getTexts(text): + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = ''.join(norm_text_list) + + return phones,bert.to(dtype),norm_text def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): t0 = ttime() @@ -380,22 +473,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) phones1 = cleaned_text_to_sequence(phones1) texts = text.split("\n") audio_opt = [] + phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language) for text in texts: - phones2, word2ph2, norm_text2 = clean_text(text, text_language) - phones2 = cleaned_text_to_sequence(phones2) - if (prompt_language == "zh"): - bert1 = get_bert_feature(norm_text1, word2ph1).to(device) - else: - bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to( - device) - if (text_language == "zh"): - bert2 = get_bert_feature(norm_text2, word2ph2).to(device) - else: - bert2 = torch.zeros((1024, len(phones2))).to(bert1) + phones2,bert2,norm_text2=get_phones_and_bert(text, text_language) bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) - all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)