Added support for mixed-language API, following the current implementation in inference_webui.py

zih-an 2024-03-01 14:38:35 +00:00
parent 0ab0e5390f
commit 07c620c17e

api.py (110 changed lines)

@@ -80,6 +80,23 @@ RESP:
 Failure: json, 400
+### Dynamically switch base models
+endpoint: `/set_model`
+GET:
+`http://127.0.0.1:9880/set_model?gpt_model_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt&sovits_model_path=GPT_SoVITS/pretrained_models/s2G488k.pth`
+POST:
+```json
+{
+    "gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
+    "sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth"
+}
+```
+RESP:
+Success: json, HTTP code 200
 ### Command control
 endpoint: `/control`
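As a quick reference for the new `/set_model` endpoint, here is a minimal client-side sketch. It only assumes the GET and POST forms documented above and a server running locally on the default port 9880; the `requests` dependency is an assumption of the sketch, not part of the commit.

```python
import requests

# Hypothetical client sketch: switch the loaded GPT and SoVITS weights through
# the /set_model endpoint documented above. Assumes api.py is running locally
# on the default port 9880.
BASE = "http://127.0.0.1:9880"
payload = {
    "gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
    "sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
}

# POST form with a JSON body ...
resp = requests.post(f"{BASE}/set_model", json=payload)
print(resp.status_code)  # 200 on success, per the docstring above

# ... or the equivalent GET form with query parameters.
resp = requests.get(f"{BASE}/set_model", params=payload)
print(resp.status_code)
```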
@@ -126,6 +143,7 @@ from module.models import SynthesizerTrn
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
+import LangSegment
 from module.mel_processing import spectrogram_torch
 from my_utils import load_audio
 import config as global_config
@@ -350,9 +368,84 @@ dict_language = {
     "JA": "ja",
     "zh": "zh",
     "en": "en",
-    "ja": "ja"
+    "ja": "ja",
+    "auto": "auto",
+    "中英混合": "zh",
+    "日英混合": "ja",
+    "多语种混合": "auto"
 }
+
+dtype = torch.float16 if is_half == True else torch.float32
+
+
+def get_bert_inf(phones, word2ph, norm_text, language):
+    language = language.replace("all_", "")
+    if language == "zh":
+        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
+    else:
+        bert = torch.zeros(
+            (1024, len(phones)),
+            dtype=torch.float16 if is_half == True else torch.float32,
+        ).to(device)
+
+    return bert
+
+
+def clean_text_inf(text, language):
+    phones, word2ph, norm_text = clean_text(text, language)
+    phones = cleaned_text_to_sequence(phones)
+    return phones, word2ph, norm_text
+
+
+# Mixed language support
+def get_phones_and_bert(text, language):
+    if language in {"en", "all_zh", "all_ja"}:
+        language = language.replace("all_", "")
+        if language == "en":
+            LangSegment.setfilters(["en"])
+            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
+        else:
+            # Chinese and Japanese kanji cannot be told apart, so follow the language the user specified
+            formattext = text
+        while "  " in formattext:
+            formattext = formattext.replace("  ", " ")
+        phones, word2ph, norm_text = clean_text_inf(formattext, language)
+        if language == "zh":
+            bert = get_bert_feature(norm_text, word2ph).to(device)
+        else:
+            bert = torch.zeros(
+                (1024, len(phones)),
+                dtype=torch.float16 if is_half == True else torch.float32,
+            ).to(device)
+    elif language in {"zh", "ja", "auto"}:
+        textlist = []
+        langlist = []
+        LangSegment.setfilters(["zh", "ja", "en"])
+        if language == "auto":
+            for tmp in LangSegment.getTexts(text):
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        else:
+            for tmp in LangSegment.getTexts(text):
+                if tmp["lang"] == "en":
+                    langlist.append(tmp["lang"])
+                else:
+                    # Chinese and Japanese kanji cannot be told apart, so follow the language the user specified
+                    langlist.append(language)
+                textlist.append(tmp["text"])
+        print(textlist)
+        print(langlist)
+        phones_list = []
+        bert_list = []
+        norm_text_list = []
+        for i in range(len(textlist)):
+            lang = langlist[i]
+            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+            bert = get_bert_inf(phones, word2ph, norm_text, lang)
+            phones_list.append(phones)
+            norm_text_list.append(norm_text)
+            bert_list.append(bert)
+        bert = torch.cat(bert_list, dim=1)
+        phones = sum(phones_list, [])
+        norm_text = ''.join(norm_text_list)
+
+    return phones, bert.to(dtype), norm_text
+
+
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
     t0 = ttime()
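To make the merging step in `get_phones_and_bert` concrete, below is a small self-contained sketch (independent of the models loaded by api.py) of how per-segment BERT features of shape (1024, n_phones) are concatenated along the phoneme axis and how the per-segment phoneme lists are flattened, mirroring the `torch.cat` and `sum(..., [])` calls above. The segment lengths are invented purely for illustration.

```python
import torch

# Illustrative stand-ins for the per-segment outputs built inside
# get_phones_and_bert: one phoneme-ID list and one (1024, n_phones)
# feature tensor per language segment. The numbers are made up.
phones_list = [[3, 15, 27], [8, 9], [41, 42, 43, 44]]
bert_list = [torch.zeros(1024, len(p)) for p in phones_list]

# Segments are merged exactly as in the new code: features are concatenated
# along dim=1 (the phoneme axis) and phoneme lists are flattened in order.
bert = torch.cat(bert_list, dim=1)   # shape: (1024, 3 + 2 + 4) = (1024, 9)
phones = sum(phones_list, [])        # [3, 15, 27, 8, 9, 41, 42, 43, 44]

assert bert.shape[1] == len(phones)
print(bert.shape, phones)

# With the LangSegment package installed, the segmentation step itself is:
#     LangSegment.setfilters(["zh", "ja", "en"])
#     for tmp in LangSegment.getTexts(mixed_text):
#         print(tmp["lang"], tmp["text"])
# which is how get_phones_and_bert fills textlist and langlist.
```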
@@ -380,22 +473,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
-    phones1 = cleaned_text_to_sequence(phones1)
     texts = text.split("\n")
     audio_opt = []
+    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
     for text in texts:
-        phones2, word2ph2, norm_text2 = clean_text(text, text_language)
-        phones2 = cleaned_text_to_sequence(phones2)
-        if (prompt_language == "zh"):
-            bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
-        else:
-            bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
-                device)
-        if (text_language == "zh"):
-            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
-        else:
-            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
         bert = torch.cat([bert1, bert2], 1)
-        all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
         prompt = prompt_semantic.unsqueeze(0).to(device)
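With these changes, a client can request mixed-language synthesis by passing one of the new language keys ("auto", "中英混合", "日英混合", "多语种混合") as text_language; dict_language then routes the request through get_phones_and_bert. The sketch below is a hedged usage example: the synthesis endpoint and its query parameter names (refer_wav_path, prompt_text, prompt_language, text, text_language) come from the rest of api.py, which this excerpt does not show, so treat them as assumptions.

```python
import requests

# Hypothetical client call; parameter names are assumed from the rest of
# api.py (not part of this diff), and the file paths are placeholders.
params = {
    "refer_wav_path": "reference.wav",                    # reference audio path
    "prompt_text": "text spoken in the reference audio",  # placeholder prompt text
    "prompt_language": "zh",
    "text": "大家好, this is a mixed-language sentence.",
    "text_language": "auto",  # one of the new keys mapped by dict_language
}
resp = requests.get("http://127.0.0.1:9880/", params=params)
with open("output.wav", "wb") as f:
    f.write(resp.content)
```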