mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-08 16:00:01 +08:00
Add zh/jp/en mix
This commit is contained in:
parent
7bc0836d99
commit
436032214a
116
api.py
116
api.py
@ -111,6 +111,7 @@ sys.path.append(now_dir)
|
|||||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
|
import LangSegment
|
||||||
from time import time as ttime
|
from time import time as ttime
|
||||||
import torch
|
import torch
|
||||||
import librosa
|
import librosa
|
||||||
@ -249,6 +250,8 @@ def change_sovits_weights(sovits_path):
|
|||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
with open("./sweight.txt", "w", encoding="utf-8") as f:
|
with open("./sweight.txt", "w", encoding="utf-8") as f:
|
||||||
f.write(sovits_path)
|
f.write(sovits_path)
|
||||||
|
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
def change_gpt_weights(gpt_path):
|
||||||
global hz, max_sec, t2s_model, config
|
global hz, max_sec, t2s_model, config
|
||||||
hz = 50
|
hz = 50
|
||||||
@ -283,6 +286,83 @@ def get_bert_feature(text, word2ph):
|
|||||||
return phone_level_feature.T
|
return phone_level_feature.T
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text_inf(text, language):
|
||||||
|
phones, word2ph, norm_text = clean_text(text, language)
|
||||||
|
phones = cleaned_text_to_sequence(phones)
|
||||||
|
return phones, word2ph, norm_text
|
||||||
|
|
||||||
|
|
||||||
|
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||||
|
language=language.replace("all_","")
|
||||||
|
if language == "zh":
|
||||||
|
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
|
||||||
|
else:
|
||||||
|
bert = torch.zeros(
|
||||||
|
(1024, len(phones)),
|
||||||
|
dtype=torch.float16 if is_half == True else torch.float32,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
return bert
|
||||||
|
|
||||||
|
|
||||||
|
def get_phones_and_bert(text,language):
|
||||||
|
if language in {"en","all_zh","all_ja"}:
|
||||||
|
language = language.replace("all_","")
|
||||||
|
if language == "en":
|
||||||
|
LangSegment.setfilters(["en"])
|
||||||
|
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||||
|
else:
|
||||||
|
# 因无法区别中日文汉字,以用户输入为准
|
||||||
|
formattext = text
|
||||||
|
while " " in formattext:
|
||||||
|
formattext = formattext.replace(" ", " ")
|
||||||
|
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||||
|
if language == "zh":
|
||||||
|
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||||
|
else:
|
||||||
|
bert = torch.zeros(
|
||||||
|
(1024, len(phones)),
|
||||||
|
dtype=torch.float16 if is_half == True else torch.float32,
|
||||||
|
).to(device)
|
||||||
|
elif language in {"zh", "ja","auto"}:
|
||||||
|
textlist=[]
|
||||||
|
langlist=[]
|
||||||
|
LangSegment.setfilters(["zh","ja","en","ko"])
|
||||||
|
if language == "auto":
|
||||||
|
for tmp in LangSegment.getTexts(text):
|
||||||
|
if tmp["lang"] == "ko":
|
||||||
|
langlist.append("zh")
|
||||||
|
textlist.append(tmp["text"])
|
||||||
|
else:
|
||||||
|
langlist.append(tmp["lang"])
|
||||||
|
textlist.append(tmp["text"])
|
||||||
|
else:
|
||||||
|
for tmp in LangSegment.getTexts(text):
|
||||||
|
if tmp["lang"] == "en":
|
||||||
|
langlist.append(tmp["lang"])
|
||||||
|
else:
|
||||||
|
# 因无法区别中日文汉字,以用户输入为准
|
||||||
|
langlist.append(language)
|
||||||
|
textlist.append(tmp["text"])
|
||||||
|
print(textlist)
|
||||||
|
print(langlist)
|
||||||
|
phones_list = []
|
||||||
|
bert_list = []
|
||||||
|
norm_text_list = []
|
||||||
|
for i in range(len(textlist)):
|
||||||
|
lang = langlist[i]
|
||||||
|
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||||
|
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||||
|
phones_list.append(phones)
|
||||||
|
norm_text_list.append(norm_text)
|
||||||
|
bert_list.append(bert)
|
||||||
|
bert = torch.cat(bert_list, dim=1)
|
||||||
|
phones = sum(phones_list, [])
|
||||||
|
norm_text = ''.join(norm_text_list)
|
||||||
|
|
||||||
|
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
|
||||||
|
|
||||||
|
|
||||||
n_semantic = 1024
|
n_semantic = 1024
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
hps = dict_s2["config"]
|
hps = dict_s2["config"]
|
||||||
@ -342,15 +422,18 @@ def get_spepc(hps, filename):
|
|||||||
|
|
||||||
|
|
||||||
dict_language = {
|
dict_language = {
|
||||||
"中文": "zh",
|
"中文": "all_zh",
|
||||||
"英文": "en",
|
"英文": "en",
|
||||||
"日文": "ja",
|
"日文": "all_ja",
|
||||||
"ZH": "zh",
|
"中英混合": "zh",
|
||||||
"EN": "en",
|
"日英混合": "ja",
|
||||||
"JA": "ja",
|
"多语种混合": "auto", #多语种启动切分识别语种
|
||||||
"zh": "zh",
|
"all_zh": "all_zh",
|
||||||
"en": "en",
|
"en": "en",
|
||||||
"ja": "ja"
|
"all_ja": "all_ja",
|
||||||
|
"zh": "zh",
|
||||||
|
"ja": "ja",
|
||||||
|
"auto": "auto",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -374,25 +457,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
codes = vq_model.extract_latent(ssl_content)
|
codes = vq_model.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
prompt_language = dict_language[prompt_language]
|
prompt_language = dict_language[prompt_language.lower()]
|
||||||
text_language = dict_language[text_language]
|
text_language = dict_language[text_language.lower()]
|
||||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
|
||||||
phones1 = cleaned_text_to_sequence(phones1)
|
|
||||||
texts = text.split("\n")
|
texts = text.split("\n")
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
|
|
||||||
for text in texts:
|
for text in texts:
|
||||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
|
||||||
phones2 = cleaned_text_to_sequence(phones2)
|
|
||||||
if (prompt_language == "zh"):
|
|
||||||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
|
||||||
else:
|
|
||||||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
|
|
||||||
device)
|
|
||||||
if (text_language == "zh"):
|
|
||||||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
|
||||||
else:
|
|
||||||
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
|
|
||||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user