[fast_inference] 优化batch inference的mask策略 (#1477)

* 优化了batch inference的mask策略,使音频合成的质量更加稳定;改善了一些代码逻辑。

* 删除无用代码
This commit is contained in:
ChasonJiang 2024-08-16 10:49:53 +08:00 committed by GitHub
parent 7c43b41e6d
commit f5a5f1890f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 153 additions and 79 deletions

2
.gitignore vendored
View File

@ -10,3 +10,5 @@ reference
GPT_weights
SoVITS_weights
TEMP
ffmpeg*
ffprobe*

View File

@ -1,9 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import math
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import List
from typing import List, Optional
import torch
from tqdm import tqdm
@ -38,6 +39,34 @@ default_config = {
"EOS": 1024,
}
@torch.jit.script
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
scale_factor = scale
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask, float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
@torch.jit.script
class T2SMLP:
@ -86,10 +115,16 @@ class T2SBlock:
self.norm_eps2 = norm_eps2
@torch.jit.ignore
def to_mask(self, x, padding_mask):
return x*padding_mask if padding_mask is not None else x
def process_prompt(self, x, attn_mask : torch.Tensor, padding_mask:torch.Tensor=None):
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
@ -106,27 +141,48 @@ class T2SBlock:
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
attn = scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = self.to_mask(x + attn, padding_mask)
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = self.to_mask(x + self.mlp.forward(self.to_mask(x, padding_mask)), padding_mask)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
if padding_mask is not None:
for i in range(batch_size):
# mask = padding_mask[i,:,0]
idx = torch.where(padding_mask[i,:,0]==False)[0]
x_item = x[i,idx,:].unsqueeze(0)
attn_item = attn[i,idx,:].unsqueeze(0)
x_item = x_item + attn_item
x_item = F.layer_norm(
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm(
x_item,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x[i,idx,:] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask)
else:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache):
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
@ -141,7 +197,7 @@ class T2SBlock:
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
@ -169,8 +225,8 @@ class T2STransformer:
self.blocks = blocks
def process_prompt(
self, x, attn_mask : torch.Tensor,
padding_mask : torch.Tensor=None,
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
@ -181,10 +237,10 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
self, x:torch.Tensor, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : Optional[torch.Tensor]=None,
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
return x, k_cache, v_cache
@ -506,12 +562,12 @@ class Text2SemanticDecoder(nn.Module):
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel_batch_infer_with_flash_attn(
def infer_panel_batch_infer(
self,
x:torch.LongTensor, #####全部文本token
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@ -519,30 +575,21 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
# # fp16 会对结果产生影响和没pad相比
# bert_feature_dtype = bert_feature[0].dtype
# if not hasattr(self.bert_proj, "dtype"):
# self.bert_proj.dtype = torch.float32
# self.bert_proj=self.bert_proj.float()
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_0307(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
## 先对phones进行embedding、对bert_features进行project再pad到相同长度padding策略会影响T2S模型生成的结果。
## pad之后再进行Linear会有误差和没pad相比就离谱。。。
max_len = kwargs.get("max_len",x_lens.max())
# for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_list = [self.ar_text_embedding(item) for item in x]
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
x_list.append(x_item)
x = torch.stack(x_list, dim=0)
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
bert_feature = torch.stack(bert_features_list, dim=0)
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2).float()).to(dtype=bert_feature_dtype)
# x = self.ar_text_embedding(x)
x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder
y = prompts
@ -593,16 +640,17 @@ class Text2SemanticDecoder(nn.Module):
value=False,
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
for i in range(bsz):
l = x_lens[i]
_xy_padding_mask[i,l:max_len,:]=True
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
xy_padding_mask = ~xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
xy_padding_mask = xy_padding_mask.to(dtype=x.dtype)
xy_attn_mask = xy_attn_mask.bool()
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
###### decode #####
y_list = [None]*y.shape[0]
@ -612,27 +660,32 @@ class Text2SemanticDecoder(nn.Module):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
if idx == 0:
xy_attn_mask = None
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
logits = logits[:, :-1]
else:
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
(self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
@ -647,6 +700,7 @@ class Text2SemanticDecoder(nn.Module):
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
@ -682,13 +736,14 @@ class Text2SemanticDecoder(nn.Module):
if ref_free:
return y_list, [0]*x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_0307(self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@ -699,9 +754,9 @@ class Text2SemanticDecoder(nn.Module):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_with_flash_attn_only(x[i].unsqueeze(0),
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0),
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
@ -714,7 +769,7 @@ class Text2SemanticDecoder(nn.Module):
return y_list, idx_list
def infer_panel_with_flash_attn_only(
def infer_panel_naive(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
@ -771,8 +826,9 @@ class Text2SemanticDecoder(nn.Module):
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).unsqueeze(0).expand(bsz*self.num_head, -1, -1).view(bsz, self.num_head, src_len, src_len).to(x.device)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
xy_attn_mask = xy_attn_mask.bool()
# new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
# xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
@ -813,3 +869,18 @@ class Text2SemanticDecoder(nn.Module):
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx - 1
def infer_panel(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)

View File

@ -81,6 +81,8 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}
configs:dict = None
languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
def __init__(self, configs: Union[dict, str]=None):
# 设置默认配置文件路径
@ -97,7 +99,7 @@ class TTS_Config:
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
@ -138,8 +140,7 @@ class TTS_Config:
self.hop_length:int = 640
self.win_length:int = 2048
self.n_speakers:int = 300
self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
def _load_configs(self, configs_path: str)->dict:
@ -489,8 +490,8 @@ class TTS:
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
all_bert_max_len = 0
all_phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
@ -505,8 +506,8 @@ class TTS:
all_phones = phones
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1])
all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
@ -520,7 +521,7 @@ class TTS:
all_bert_features_batch = all_bert_features_list
max_len = max(bert_max_len, phones_max_len)
max_len = max(all_bert_max_len, all_phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad。padding策略会影响T2S模型生成的结果但不直接影响复读概率。影响复读概率的主要因素是mask的策略
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
@ -630,7 +631,7 @@ class TTS:
if parallel_infer:
print(i18n("并行推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
print(i18n("并行推理模式已关闭"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307
@ -942,4 +943,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
return processed_audio
return processed_audio