mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
新增VITS批量推理 GPT_SoVITS/TTS_infer_pack/TTS.py
fix some bugs GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py fix some bugs GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py fix some bugs GPT_SoVITS/inference_webui.py fix some bugs GPT_SoVITS/module/models.py
This commit is contained in:
parent
174c4bbab3
commit
3535cfe3b0
@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os, sys
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@ -366,6 +367,7 @@ class TTS:
|
||||
for batch_idx, index_list in enumerate(batch_index_list):
|
||||
item_list = [data[idx] for idx in index_list]
|
||||
phones_list = []
|
||||
phones_len_list = []
|
||||
# bert_features_list = []
|
||||
all_phones_list = []
|
||||
all_phones_len_list = []
|
||||
@ -375,24 +377,26 @@ class TTS:
|
||||
phones_max_len = 0
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
|
||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||
else:
|
||||
all_bert_features = item["bert_features"]
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
all_phones = phones.clone()
|
||||
all_phones = phones
|
||||
# norm_text = item["norm_text"]
|
||||
|
||||
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
|
||||
phones_max_len = max(phones_max_len, phones.shape[-1])
|
||||
|
||||
phones_list.append(phones)
|
||||
phones_len_list.append(phones.shape[-1])
|
||||
all_phones_list.append(all_phones)
|
||||
all_phones_len_list.append(all_phones.shape[-1])
|
||||
all_bert_features_list.append(all_bert_features)
|
||||
norm_text_batch.append(item["norm_text"])
|
||||
|
||||
phones_batch = phones_list
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
@ -406,6 +410,7 @@ class TTS:
|
||||
|
||||
batch = {
|
||||
"phones": phones_batch,
|
||||
"phones_len": torch.LongTensor(phones_len_list),
|
||||
"all_phones": all_phones_batch,
|
||||
"all_phones_len": torch.LongTensor(all_phones_len_list),
|
||||
"all_bert_features": all_bert_features_batch,
|
||||
@ -491,7 +496,11 @@ class TTS:
|
||||
|
||||
if split_bucket:
|
||||
print(i18n("分桶处理模式已开启"))
|
||||
|
||||
|
||||
# if vits_batched_inference:
|
||||
# print(i18n("VITS批量推理模式已开启"))
|
||||
# else:
|
||||
# print(i18n("VITS单句推理模式已开启"))
|
||||
|
||||
no_prompt_text = False
|
||||
if prompt_text in [None, ""]:
|
||||
@ -529,7 +538,6 @@ class TTS:
|
||||
|
||||
###### text preprocessing ########
|
||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
||||
audio = []
|
||||
t1 = ttime()
|
||||
data, batch_index_list = self.to_batch(data,
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
@ -538,24 +546,23 @@ class TTS:
|
||||
split_bucket=split_bucket
|
||||
)
|
||||
t2 = ttime()
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
dtype=torch.float16 if self.configs.is_half else torch.float32,
|
||||
device=self.configs.device
|
||||
)
|
||||
|
||||
|
||||
###### inference ######
|
||||
t_34 = 0.0
|
||||
t_45 = 0.0
|
||||
audio = []
|
||||
for item in data:
|
||||
t3 = ttime()
|
||||
batch_phones = item["phones"]
|
||||
batch_phones_len = item["phones_len"]
|
||||
all_phoneme_ids = item["all_phones"]
|
||||
all_phoneme_lens = item["all_phones_len"]
|
||||
all_bert_features = item["all_bert_features"]
|
||||
norm_text = item["norm_text"]
|
||||
|
||||
# batch_phones = batch_phones.to(self.configs.device)
|
||||
batch_phones_len = batch_phones_len.to(self.configs.device)
|
||||
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
|
||||
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
|
||||
all_bert_features = all_bert_features.to(self.configs.device)
|
||||
@ -566,7 +573,7 @@ class TTS:
|
||||
if no_prompt_text :
|
||||
prompt = None
|
||||
else:
|
||||
prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
@ -583,41 +590,52 @@ class TTS:
|
||||
t4 = ttime()
|
||||
t_34 += t4 - t3
|
||||
|
||||
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device)
|
||||
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
refer_audio_spepc = refer_audio_spepc.half()
|
||||
|
||||
## 直接对batch进行decode 生成的音频会有问题
|
||||
|
||||
|
||||
batch_audio_fragment = []
|
||||
|
||||
# ## vits并行推理 method 1
|
||||
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
|
||||
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
|
||||
# batch_phones = batch_phones.to(self.configs.device)
|
||||
# batch_audio_fragment =(self.vits_model.decode(
|
||||
# pred_semantic, batch_phones, refer_audio_spepc
|
||||
# ).detach()[:, 0, :])
|
||||
# max_audio=torch.abs(batch_audio_fragment).max()#简单防止16bit爆音
|
||||
# if max_audio>1: batch_audio_fragment/=max_audio
|
||||
# batch_audio_fragment = batch_audio_fragment.cpu().numpy()
|
||||
# batch_audio_fragment = (self.vits_model.batched_decode(
|
||||
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
|
||||
# ))
|
||||
|
||||
## 改成串行处理
|
||||
batch_audio_fragment = []
|
||||
for i, idx in enumerate(idx_list):
|
||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
audio_fragment =(self.vits_model.decode(
|
||||
_pred_semantic, phones, refer_audio_spepc
|
||||
).detach()[0, 0, :])
|
||||
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
|
||||
if max_audio>1: audio_fragment/=max_audio
|
||||
audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0)
|
||||
batch_audio_fragment.append(
|
||||
audio_fragment.cpu().numpy()
|
||||
) ###试试重建不带上prompt部分
|
||||
# ## vits并行推理 method 2
|
||||
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
||||
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
|
||||
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
|
||||
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||
_batch_audio_fragment = (self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones,refer_audio_spepc
|
||||
).detach()[0, 0, :])
|
||||
audio_frag_end_idx.insert(0, 0)
|
||||
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
|
||||
|
||||
|
||||
# ## vits串行推理
|
||||
# for i, idx in enumerate(idx_list):
|
||||
# phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||
# _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
# audio_fragment =(self.vits_model.decode(
|
||||
# _pred_semantic, phones, refer_audio_spepc
|
||||
# ).detach()[0, 0, :])
|
||||
# batch_audio_fragment.append(
|
||||
# audio_fragment
|
||||
# ) ###试试重建不带上prompt部分
|
||||
|
||||
t5 = ttime()
|
||||
t_45 += t5 - t4
|
||||
if return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
|
||||
yield self.audio_postprocess(batch_audio_fragment,
|
||||
yield self.audio_postprocess([batch_audio_fragment],
|
||||
self.configs.sampling_rate,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
@ -626,7 +644,8 @@ class TTS:
|
||||
audio.append(batch_audio_fragment)
|
||||
|
||||
if self.stop_flag:
|
||||
yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
|
||||
dtype=np.int16)
|
||||
return
|
||||
|
||||
if not return_fragment:
|
||||
@ -640,15 +659,30 @@ class TTS:
|
||||
|
||||
|
||||
def audio_postprocess(self,
|
||||
audio:np.ndarray,
|
||||
audio:List[torch.Tensor],
|
||||
sr:int,
|
||||
batch_index_list:list=None,
|
||||
speed_factor:float=1.0,
|
||||
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
dtype=torch.float16 if self.configs.is_half else torch.float32,
|
||||
device=self.configs.device
|
||||
)
|
||||
|
||||
for i, batch in enumerate(audio):
|
||||
for j, audio_fragment in enumerate(batch):
|
||||
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
|
||||
if max_audio>1: audio_fragment/=max_audio
|
||||
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
|
||||
audio[i][j] = audio_fragment.cpu().numpy()
|
||||
|
||||
|
||||
if split_bucket:
|
||||
audio = self.recovery_order(audio, batch_index_list)
|
||||
else:
|
||||
audio = [item for batch in audio for item in batch]
|
||||
# audio = [item for batch in audio for item in batch]
|
||||
audio = sum(audio, [])
|
||||
|
||||
|
||||
audio = np.concatenate(audio, 0)
|
||||
|
@ -10,7 +10,7 @@ from typing import Dict, List, Tuple
|
||||
from text.cleaner import clean_text
|
||||
from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
|
||||
# from tools.i18n.i18n import I18nAuto
|
||||
|
||||
@ -39,6 +39,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model:AutoModelForMaskedLM,
|
||||
tokenizer:AutoTokenizer, device:torch.device):
|
||||
@ -74,12 +78,18 @@ class TextPreprocessor:
|
||||
_texts = text.split("\n")
|
||||
_texts = merge_short_text_in_array(_texts, 5)
|
||||
texts = []
|
||||
|
||||
|
||||
|
||||
for text in _texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
||||
texts.append(text)
|
||||
|
||||
# 解决句子过长导致Bert报错的问题
|
||||
texts.extend(split_big_text(text))
|
||||
|
||||
|
||||
return texts
|
||||
|
||||
@ -176,4 +186,8 @@ class TextPreprocessor:
|
||||
dtype=torch.float32,
|
||||
).to(self.device)
|
||||
|
||||
return feature
|
||||
return feature
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -24,6 +24,32 @@ def register_method(name):
|
||||
|
||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||
|
||||
def split_big_text(text, max_len=510):
|
||||
# 定义全角和半角标点符号
|
||||
punctuation = "".join(splits)
|
||||
|
||||
# 切割文本
|
||||
segments = re.split('([' + punctuation + '])', text)
|
||||
|
||||
# 初始化结果列表和当前片段
|
||||
result = []
|
||||
current_segment = ''
|
||||
|
||||
for segment in segments:
|
||||
# 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
|
||||
if len(current_segment + segment) > max_len:
|
||||
result.append(current_segment)
|
||||
current_segment = segment
|
||||
else:
|
||||
current_segment += segment
|
||||
|
||||
# 将最后一个片段加入结果列表
|
||||
if current_segment:
|
||||
result.append(current_segment)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||
@ -121,6 +147,6 @@ def cut5(inp):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
method = get_method("cut1")
|
||||
method = get_method("cut5")
|
||||
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
|
||||
|
@ -29,9 +29,13 @@ is_share = eval(is_share)
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
|
||||
gpt_path = os.environ.get("gpt_path", None)
|
||||
sovits_path = os.environ.get("sovits_path", None)
|
||||
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
||||
bert_path = os.environ.get("bert_path", None)
|
||||
|
||||
import gradio as gr
|
||||
from TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
|
||||
from TTS_infer_pack.text_segmentation_method import get_method
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
@ -65,6 +69,15 @@ cut_method = {
|
||||
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
|
||||
tts_config.device = device
|
||||
tts_config.is_half = is_half
|
||||
if gpt_path is not None:
|
||||
tts_config.t2s_weights_path = gpt_path
|
||||
if sovits_path is not None:
|
||||
tts_config.vits_weights_path = sovits_path
|
||||
if cnhubert_base_path is not None:
|
||||
tts_config.cnhuhbert_base_path = cnhubert_base_path
|
||||
if bert_path is not None:
|
||||
tts_config.bert_base_path = bert_path
|
||||
|
||||
tts_pipline = TTS(tts_config)
|
||||
gpt_path = tts_config.t2s_weights_path
|
||||
sovits_path = tts_config.vits_weights_path
|
||||
@ -169,7 +182,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
|
||||
batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
||||
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
@ -181,7 +194,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
)
|
||||
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
|
||||
with gr.Row():
|
||||
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
|
||||
# with gr.Column():
|
||||
output = gr.Audio(label=i18n("输出的语音"))
|
||||
with gr.Row():
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import List
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@ -985,6 +986,55 @@ class SynthesizerTrn(nn.Module):
|
||||
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
|
||||
ge = None
|
||||
if refer is not None:
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(
|
||||
commons.sequence_mask(refer_lengths, refer.size(2)), 1
|
||||
).to(refer.dtype)
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
|
||||
# y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
|
||||
# codes.dtype
|
||||
# )
|
||||
y_lengths = (y_lengths * 2).long().to(codes.device)
|
||||
text_lengths = text_lengths.long().to(text.device)
|
||||
# y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||
# text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
# 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题?
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
z_masked = (z * y_mask)[:, :, :]
|
||||
|
||||
# 串行。把padding部分去掉再decode
|
||||
o_list:List[torch.Tensor] = []
|
||||
for i in range(z_masked.shape[0]):
|
||||
z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
|
||||
o = self.dec(z_slice, g=ge)[0, 0, :].detach()
|
||||
o_list.append(o)
|
||||
|
||||
# 并行(会有问题)。先decode,再把padding的部分去掉
|
||||
# o = self.dec(z_masked, g=ge)
|
||||
# upsample_rate = int(math.prod(self.upsample_rates))
|
||||
# o_lengths = y_lengths*upsample_rate
|
||||
# o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
|
||||
|
||||
return o_list
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user