diff --git a/.gitignore b/.gitignore
index 754b06b..343d557 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,5 @@ reference
 GPT_weights
 SoVITS_weights
 TEMP
+ffmpeg*
+ffprobe*
\ No newline at end of file
diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index aef7825..cb952e6 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -1,9 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
+import math
 import os, sys
 now_dir = os.getcwd()
 sys.path.append(now_dir)
-from typing import List
+from typing import List, Optional

 import torch
 from tqdm import tqdm
@@ -38,6 +39,34 @@ default_config = {
     "EOS": 1024,
 }

+@torch.jit.script
+# Efficient implementation equivalent to the following:
+def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
+    B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
+    if scale is None:
+        scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
+    else:
+        scale_factor = scale
+    attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask, float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_weight.masked_fill_(attn_mask, 0)
+        else:
+            attn_mask[attn_mask!=float("-inf")] = 0
+            attn_mask[attn_mask==float("-inf")] = 1
+            attn_weight.masked_fill_(attn_mask, 0)
+
+    return attn_weight @ value

 @torch.jit.script
 class T2SMLP:
@@ -86,10 +115,16 @@ class T2SBlock:
         self.norm_eps2 = norm_eps2

     @torch.jit.ignore
-    def to_mask(self, x, padding_mask):
-        return x*padding_mask if padding_mask is not None else x
-
-    def process_prompt(self, x, attn_mask : torch.Tensor, padding_mask:torch.Tensor=None):
+    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
+        if padding_mask is None:
+            return x
+
+        if padding_mask.dtype == torch.bool:
+            return x.masked_fill(padding_mask, 0)
+        else:
+            return x * padding_mask
+
+    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
         q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)

@@ -106,27 +141,48 @@ class T2SBlock:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
+        attn = scaled_dot_product_attention(q, k, v, attn_mask)

         attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
         attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

-        x = self.to_mask(x + attn, padding_mask)
-        x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
-        x = self.to_mask(x + self.mlp.forward(self.to_mask(x, padding_mask)), padding_mask)
-        x = F.layer_norm(
-            x,
-            [self.hidden_dim],
-            self.norm_w2,
-            self.norm_b2,
-            self.norm_eps2,
-        )
+        if padding_mask is not None:
+            for i in range(batch_size):
+                # mask = padding_mask[i,:,0]
+                idx = torch.where(padding_mask[i,:,0]==False)[0]
+                x_item = x[i,idx,:].unsqueeze(0)
+                attn_item = attn[i,idx,:].unsqueeze(0)
+                x_item = x_item + attn_item
+                x_item = F.layer_norm(
+                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+                )
+                x_item = x_item + self.mlp.forward(x_item)
+                x_item = F.layer_norm(
+                    x_item,
+                    [self.hidden_dim],
+                    self.norm_w2,
+                    self.norm_b2,
+                    self.norm_eps2,
+                )
+                x[i,idx,:] = x_item.squeeze(0)
+            x = self.to_mask(x, padding_mask)
+        else:
+            x = x + attn
+            x = F.layer_norm(
+                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            )
+            x = x + self.mlp.forward(x)
+            x = F.layer_norm(
+                x,
+                [self.hidden_dim],
+                self.norm_w2,
+                self.norm_b2,
+                self.norm_eps2,
+            )
         return x, k_cache, v_cache
-
-    def decode_next_token(self, x, k_cache, v_cache):
+
+    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         k_cache = torch.cat([k_cache, k], dim=1)
@@ -141,7 +197,7 @@ class T2SBlock:
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-        attn = F.scaled_dot_product_attention(q, k, v)
+        attn = scaled_dot_product_attention(q, k, v, attn_mask)

         attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
         attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)

@@ -169,8 +225,8 @@ class T2STransformer:
         self.blocks = blocks

     def process_prompt(
-        self, x, attn_mask : torch.Tensor,
-        padding_mask : torch.Tensor=None,
+        self, x:torch.Tensor, attn_mask : torch.Tensor,
+        padding_mask : Optional[torch.Tensor]=None,
         ):
         k_cache : List[torch.Tensor] = []
         v_cache : List[torch.Tensor] = []
@@ -181,10 +237,10 @@ class T2STransformer:
         return x, k_cache, v_cache

     def decode_next_token(
-        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+        self, x:torch.Tensor, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : Optional[torch.Tensor]=None,
         ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
         return x, k_cache, v_cache


@@ -506,12 +562,12 @@ class Text2SemanticDecoder(nn.Module):
         # 错位
         return targets[:, :-1], targets[:, 1:]

-    def infer_panel_batch_infer_with_flash_attn(
+    def infer_panel_batch_infer(
         self,
-        x:torch.LongTensor,  #####全部文本token
+        x:List[torch.LongTensor],  #####全部文本token
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor,  ####参考音频token
-        bert_feature:torch.LongTensor,
+        bert_feature:List[torch.LongTensor],
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
@@ -519,30 +575,21 @@ class Text2SemanticDecoder(nn.Module):
         repetition_penalty: float = 1.35,
         **kwargs,
     ):
-        # # fp16 会对结果产生影响(和没pad相比)
-        # bert_feature_dtype = bert_feature[0].dtype
-        # if not hasattr(self.bert_proj, "dtype"):
-        #     self.bert_proj.dtype = torch.float32
-        #     self.bert_proj=self.bert_proj.float()
+        if prompts is None:
+            print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
switch to naive_infer") + return self.infer_panel_0307(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs) - ## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果。) - ## pad之后再进行Linear会有误差(和没pad相比),就离谱。。。 max_len = kwargs.get("max_len",x_lens.max()) - # for x_item, bert_item in zip(x, bert_feature): - # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) - x_list = [self.ar_text_embedding(item) for item in x] - x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]dict: @@ -489,8 +490,8 @@ class TTS: all_phones_len_list = [] all_bert_features_list = [] norm_text_batch = [] - bert_max_len = 0 - phones_max_len = 0 + all_bert_max_len = 0 + all_phones_max_len = 0 for item in item_list: if prompt_data is not None: all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ @@ -505,8 +506,8 @@ class TTS: all_phones = phones # norm_text = item["norm_text"] - bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) - phones_max_len = max(phones_max_len, phones.shape[-1]) + all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1]) + all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1]) phones_list.append(phones) phones_len_list.append(phones.shape[-1]) @@ -520,7 +521,7 @@ class TTS: all_bert_features_batch = all_bert_features_list - max_len = max(bert_max_len, phones_max_len) + max_len = max(all_bert_max_len, all_phones_max_len) # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) #### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) @@ -630,7 +631,7 @@ class TTS: if parallel_infer: print(i18n("并行推理模式已开启")) - self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer else: print(i18n("并行推理模式已关闭")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307 @@ -942,4 +943,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int): # 将管道输出解码为 NumPy 数组 processed_audio = np.frombuffer(out, np.int16) - return processed_audio + return processed_audio \ No newline at end of file