mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 07:02:57 +08:00
added support for mixed language API according to the current implementation in inference_webui.py
This commit is contained in:
parent
0ab0e5390f
commit
07c620c17e
110
api.py
110
api.py
@ -80,6 +80,23 @@ RESP:
|
||||
失败: json, 400
|
||||
|
||||
|
||||
### 动态更换底模
|
||||
|
||||
endpoint: `/set_model`
|
||||
|
||||
GET:
|
||||
`http://127.0.0.1:9880/set_model?gpt_model_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt&sovits_model_path=GPT_SoVITS/pretrained_models/s2G488k.pth`
|
||||
POST:
|
||||
```json
|
||||
{
|
||||
"gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||
}
|
||||
```
|
||||
|
||||
RESP:
|
||||
成功: json, http code 200
|
||||
|
||||
### 命令控制
|
||||
|
||||
endpoint: `/control`
|
||||
@ -126,6 +143,7 @@ from module.models import SynthesizerTrn
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import clean_text
|
||||
import LangSegment
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from my_utils import load_audio
|
||||
import config as global_config
|
||||
@ -350,9 +368,84 @@ dict_language = {
|
||||
"JA": "ja",
|
||||
"zh": "zh",
|
||||
"en": "en",
|
||||
"ja": "ja"
|
||||
"ja": "ja",
|
||||
"auto": "auto",
|
||||
"中英混合": "zh",
|
||||
"日英混合": "ja",
|
||||
"多语种混合": "auto"
|
||||
}
|
||||
|
||||
dtype=torch.float16 if is_half == True else torch.float32
|
||||
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||
language=language.replace("all_","")
|
||||
if language == "zh":
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
|
||||
else:
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half == True else torch.float32,
|
||||
).to(device)
|
||||
|
||||
return bert
|
||||
|
||||
def clean_text_inf(text, language):
|
||||
phones, word2ph, norm_text = clean_text(text, language)
|
||||
phones = cleaned_text_to_sequence(phones)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
# Mixed language support
|
||||
def get_phones_and_bert(text,language):
|
||||
if language in {"en","all_zh","all_ja"}:
|
||||
language = language.replace("all_","")
|
||||
if language == "en":
|
||||
LangSegment.setfilters(["en"])
|
||||
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||
else:
|
||||
# 因无法区别中日文汉字,以用户输入为准
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||
if language == "zh":
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
else:
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half == True else torch.float32,
|
||||
).to(device)
|
||||
elif language in {"zh", "ja","auto"}:
|
||||
textlist=[]
|
||||
langlist=[]
|
||||
LangSegment.setfilters(["zh","ja","en"])
|
||||
if language == "auto":
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
print(textlist)
|
||||
print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = ''.join(norm_text_list)
|
||||
|
||||
return phones,bert.to(dtype),norm_text
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
t0 = ttime()
|
||||
@ -380,22 +473,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
phones1 = cleaned_text_to_sequence(phones1)
|
||||
texts = text.split("\n")
|
||||
audio_opt = []
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||
|
||||
for text in texts:
|
||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
||||
phones2 = cleaned_text_to_sequence(phones2)
|
||||
if (prompt_language == "zh"):
|
||||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
||||
else:
|
||||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
|
||||
device)
|
||||
if (text_language == "zh"):
|
||||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
||||
else:
|
||||
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user