Update inference_webui.py

Integrate inference webui to main window.
This commit is contained in:
Ke 2024-01-21 14:19:14 +08:00 committed by GitHub
parent a3435c036e
commit 8630a9d9eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,66 +1,17 @@
import os import os
gpt_path = os.environ.get(
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np import numpy as np
import librosa,torch import librosa,torch
from feature_extractor import cnhubert from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from transformers import AutoModelForMaskedLM, AutoTokenizer
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
from text.cleaner import clean_text from text.cleaner import clean_text
from time import time as ttime from time import time as ttime
from module.mel_processing import spectrogram_torch from module.mel_processing import spectrogram_torch
from my_utils import load_audio from my_utils import load_audio
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
# bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题精度随bert_model
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
n_semantic = 1024
dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"]
class DictToAttrRecursive(dict): class DictToAttrRecursive(dict):
def __init__(self, input_dict): def __init__(self, input_dict):
@ -90,40 +41,205 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found") raise AttributeError(f"Attribute {item} not found")
hps = DictToAttrRecursive(hps) class Inference:
def __init__(self, is_half, GPT_weight_root, SoVITS_weight_root):
self.n_semantic = 1024
self.model_loaded = False
self.is_half = is_half
self.GPT_weight_root = GPT_weight_root
self.SoVITS_weight_root = SoVITS_weight_root
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn( def update_envs(self, gpt_path, sovits_path, cnhubert_base_path, bert_path):
hps.data.filter_length // 2 + 1, self.gpt_path = os.path.join(self.GPT_weight_root, gpt_path)
hps.train.segment_size // hps.data.hop_length, self.sovits_path = os.path.join(self.SoVITS_weight_root, sovits_path)
n_speakers=hps.data.n_speakers, self.cnhubert_base_path = cnhubert_base_path
**hps.model self.bert_path = bert_path
)
if is_half == True: cnhubert.cnhubert_base_path=cnhubert_base_path
vq_model = vq_model.half().to(device)
else: yield self.load_model()
vq_model = vq_model.to(device)
vq_model.eval() def load_model(self, device='cuda'):
print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) try:
hz = 50 # Load bert model
max_sec = config["data"]["max_sec"] self.device = device
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo self.tokenizer = AutoTokenizer.from_pretrained(self.bert_path)
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False) self.bert_model = AutoModelForMaskedLM.from_pretrained(self.bert_path)
t2s_model.load_state_dict(dict_s1["weight"]) if self.is_half == True:
if is_half == True: self.bert_model = self.bert_model.half().to(device)
t2s_model = t2s_model.half() else:
t2s_model = t2s_model.to(device) self.bert_model = self.bert_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()]) # Load ssl model
print("Number of parameter: %.2fM" % (total / 1e6)) dict_s1 = torch.load(self.gpt_path, map_location="cpu")
self.config = dict_s1["config"]
self.ssl_model = cnhubert.get_model()
if self.is_half == True:
self.ssl_model = self.ssl_model.half().to(device)
else:
self.ssl_model = self.ssl_model.to(device)
dict_s2=torch.load(self.sovits_path,map_location="cpu")
self.hps=dict_s2["config"]
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
# Load vq model
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
)
if self.is_half == True:
self.vq_model = self.vq_model.half().to(device)
else:
self.vq_model = self.vq_model.to(device)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
# Load t2s model
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
self.t2s_model.load_state_dict(dict_s1["weight"])
if self.is_half == True:
self.t2s_model = self.t2s_model.half()
self.t2s_model = self.t2s_model.to(device)
self.t2s_model.eval()
total = sum([param.nelement() for param in self.t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
self.model_loaded = True
return '模型加载成功'
except Exception as e:
return f'模型加载失败:{e}'
def unload_model(self):
if self.model_loaded:
try:
del self.bert_model, self.ssl_model, self.hps, self.vq_model, self.t2s_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.model_loaded = False
yield '模型卸载成功'
except Exception as e:
yield f'模型卸载失败:{e}'
else:
yield '模型未加载'
def get_bert_feature(self, text, word2ph):
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device) #####输入是long不用管精度问题精度随bert_model
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
def get_tts_wav(self, ref_wav_path, prompt_text, prompt_language, text, text_language):
if not self.model_loaded:
return
hz = 50
dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
t0 = ttime()
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
wav16k = torch.from_numpy(wav16k)
if self.is_half == True:
wav16k = wav16k.half().to(self.device)
else:
wav16k = wav16k.to(self.device)
ssl_content = self.ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = self.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
audio_opt = []
zero_wav = np.zeros(
int(self.hps.data.sampling_rate * 0.3),
dtype=np.float16 if self.is_half == True else np.float32,
)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if prompt_language == "zh":
bert1 = self.get_bert_feature(norm_text1, word2ph1).to(self.device)
else:
bert1 = torch.zeros(
(1024, len(phones1)),
dtype=torch.float16 if self.is_half == True else torch.float32,
).to(self.device)
if text_language == "zh":
bert2 = self.get_bert_feature(norm_text2, word2ph2).to(self.device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(self.device).unsqueeze(0)
bert = bert.to(self.device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(self.device)
prompt = prompt_semantic.unsqueeze(0).to(self.device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=self.config["inference"]["top_k"],
early_stop_num=hz * self.config["data"]["max_sec"]
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(self.hps, ref_wav_path) # .to(device)
if self.is_half == True:
refer = refer.half().to(self.device)
else:
refer = refer.to(self.device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
self.vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(self.device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield self.hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
def get_spepc(hps, filename): def get_spepc(hps, filename):
@ -142,119 +258,9 @@ def get_spepc(hps, filename):
return spec return spec
dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
wav16k = torch.from_numpy(wav16k)
if is_half == True:
wav16k = wav16k.half().to(device)
else:
wav16k = wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
audio_opt = []
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if prompt_language == "zh":
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros(
(1024, len(phones1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
if text_language == "zh":
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config["inference"]["top_k"],
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
} # 不考虑省略号
def split(todo_text): def split(todo_text):
splits = { "", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", } # 不考虑省略号
todo_text = todo_text.replace("……", "").replace("——", "") todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits: if todo_text[-1] not in splits:
todo_text += "" todo_text += ""
@ -314,50 +320,3 @@ def cut3(inp):
inp = inp.strip("\n") inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")]) return "\n".join(["%s" % item for item in inp.strip("").split("")])
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
)
# with gr.Tabs():
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group():
gr.Markdown(value="*请上传并填写参考信息")
with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
prompt_text = gr.Textbox(label="参考音频的文本", value="")
prompt_language = gr.Dropdown(
label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
)
gr.Markdown(value="*请填写需要合成的目标文本")
with gr.Row():
text = gr.Textbox(label="需要合成的文本", value="")
text_language = gr.Dropdown(
label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
)
inference_button = gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音")
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language],
[output],
)
gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
with gr.Row():
text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
button1 = gr.Button("凑五句一切", variant="primary")
button2 = gr.Button("凑50字一切", variant="primary")
button3 = gr.Button("按中文句号。切", variant="primary")
text_opt = gr.Textbox(label="切分后文本", value="")
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
gr.Markdown(value="后续将支持混合语种编码文本输入。")
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
server_port=infer_ttswebui,
quiet=True,
)