新增VITS批量推理 GPT_SoVITS/TTS_infer_pack/TTS.py

fix some bugs   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	fix some bugs   GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py
	fix some bugs   GPT_SoVITS/inference_webui.py
	fix some bugs   GPT_SoVITS/module/models.py
This commit is contained in:
chasonjiang 2024-03-10 21:37:28 +08:00
parent 174c4bbab3
commit 3535cfe3b0
5 changed files with 182 additions and 44 deletions

View File

@ -1,3 +1,4 @@
import math
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -366,6 +367,7 @@ class TTS:
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
@ -375,24 +377,26 @@ class TTS:
phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
phones = torch.LongTensor(item["phones"])
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]
phones = torch.LongTensor(item["phones"])
all_phones = phones.clone()
all_phones = phones
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_batch = phones_list
max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
@ -406,6 +410,7 @@ class TTS:
batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list),
"all_bert_features": all_bert_features_batch,
@ -491,7 +496,11 @@ class TTS:
if split_bucket:
print(i18n("分桶处理模式已开启"))
# if vits_batched_inference:
# print(i18n("VITS批量推理模式已开启"))
# else:
# print(i18n("VITS单句推理模式已开启"))
no_prompt_text = False
if prompt_text in [None, ""]:
@ -529,7 +538,6 @@ class TTS:
###### text preprocessing ########
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
audio = []
t1 = ttime()
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
@ -538,24 +546,23 @@ class TTS:
split_bucket=split_bucket
)
t2 = ttime()
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=torch.float16 if self.configs.is_half else torch.float32,
device=self.configs.device
)
###### inference ######
t_34 = 0.0
t_45 = 0.0
audio = []
for item in data:
t3 = ttime()
batch_phones = item["phones"]
batch_phones_len = item["phones_len"]
all_phoneme_ids = item["all_phones"]
all_phoneme_lens = item["all_phones_len"]
all_bert_features = item["all_bert_features"]
norm_text = item["norm_text"]
# batch_phones = batch_phones.to(self.configs.device)
batch_phones_len = batch_phones_len.to(self.configs.device)
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
all_bert_features = all_bert_features.to(self.configs.device)
@ -566,7 +573,7 @@ class TTS:
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
with torch.no_grad():
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
@ -583,41 +590,52 @@ class TTS:
t4 = ttime()
t_34 += t4 - t3
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device)
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
if self.configs.is_half:
refer_audio_spepc = refer_audio_spepc.half()
## 直接对batch进行decode 生成的音频会有问题
batch_audio_fragment = []
# ## vits并行推理 method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment =(self.vits_model.decode(
# pred_semantic, batch_phones, refer_audio_spepc
# ).detach()[:, 0, :])
# max_audio=torch.abs(batch_audio_fragment).max()#简单防止16bit爆音
# if max_audio>1: batch_audio_fragment/=max_audio
# batch_audio_fragment = batch_audio_fragment.cpu().numpy()
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
# ))
## 改成串行处理
batch_audio_fragment = []
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment =(self.vits_model.decode(
_pred_semantic, phones, refer_audio_spepc
).detach()[0, 0, :])
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio
audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0)
batch_audio_fragment.append(
audio_fragment.cpu().numpy()
) ###试试重建不带上prompt部分
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones,refer_audio_spepc
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
# ## vits串行推理
# for i, idx in enumerate(idx_list):
# phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
# _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
# audio_fragment =(self.vits_model.decode(
# _pred_semantic, phones, refer_audio_spepc
# ).detach()[0, 0, :])
# batch_audio_fragment.append(
# audio_fragment
# ) ###试试重建不带上prompt部分
t5 = ttime()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess(batch_audio_fragment,
yield self.audio_postprocess([batch_audio_fragment],
self.configs.sampling_rate,
batch_index_list,
speed_factor,
@ -626,7 +644,8 @@ class TTS:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
dtype=np.int16)
return
if not return_fragment:
@ -640,15 +659,30 @@ class TTS:
def audio_postprocess(self,
audio:np.ndarray,
audio:List[torch.Tensor],
sr:int,
batch_index_list:list=None,
speed_factor:float=1.0,
split_bucket:bool=True)->tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=torch.float16 if self.configs.is_half else torch.float32,
device=self.configs.device
)
for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch):
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy()
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
audio = [item for batch in audio for item in batch]
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
audio = np.concatenate(audio, 0)

View File

@ -10,7 +10,7 @@ from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
# from tools.i18n.i18n import I18nAuto
@ -39,6 +39,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
@ -74,12 +78,18 @@ class TextPreprocessor:
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if lang != "en" else "."
texts.append(text)
# 解决句子过长导致Bert报错的问题
texts.extend(split_big_text(text))
return texts
@ -176,4 +186,8 @@ class TextPreprocessor:
dtype=torch.float32,
).to(self.device)
return feature
return feature

View File

@ -24,6 +24,32 @@ def register_method(name):
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def split_big_text(text, max_len=510):
# 定义全角和半角标点符号
punctuation = "".join(splits)
# 切割文本
segments = re.split('([' + punctuation + '])', text)
# 初始化结果列表和当前片段
result = []
current_segment = ''
for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
if len(current_segment + segment) > max_len:
result.append(current_segment)
current_segment = segment
else:
current_segment += segment
# 将最后一个片段加入结果列表
if current_segment:
result.append(current_segment)
return result
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
@ -121,6 +147,6 @@ def cut5(inp):
if __name__ == '__main__':
method = get_method("cut1")
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

View File

@ -29,9 +29,13 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto
@ -65,6 +69,15 @@ cut_method = {
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
if gpt_path is not None:
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
tts_config.cnhuhbert_base_path = cnhubert_base_path
if bert_path is not None:
tts_config.bert_base_path = bert_path
tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
@ -169,7 +182,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
@ -181,7 +194,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
value=i18n("凑四句一切"),
interactive=True,
)
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
with gr.Row():
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
# with gr.Column():
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():

View File

@ -1,5 +1,6 @@
import copy
import math
from typing import List
import torch
from torch import nn
from torch.nn import functional as F
@ -985,6 +986,55 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
# y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
# codes.dtype
# )
y_lengths = (y_lengths * 2).long().to(codes.device)
text_lengths = text_lengths.long().to(text.device)
# y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
# text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
# 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题?
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
z_masked = (z * y_mask)[:, :, :]
# 串行。把padding部分去掉再decode
o_list:List[torch.Tensor] = []
for i in range(z_masked.shape[0]):
z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
o = self.dec(z_slice, g=ge)[0, 0, :].detach()
o_list.append(o)
# 并行会有问题。先decode再把padding的部分去掉
# o = self.dec(z_masked, g=ge)
# upsample_rate = int(math.prod(self.upsample_rates))
# o_lengths = y_lengths*upsample_rate
# o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
return o_list
def extract_latent(self, x):
ssl = self.ssl_proj(x)