diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index fb52891..8c8ea1a 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -145,45 +145,21 @@ class T2SBlock:
         else:
             attn = scaled_dot_product_attention(q, k, v, attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
 
-        if padding_mask is not None:
-            for i in range(batch_size):
-                # mask = padding_mask[i,:,0]
-                if self.false.device!= padding_mask.device:
-                    self.false = self.false.to(padding_mask.device)
-                idx = torch.where(padding_mask[i,:,0]==self.false)[0]
-                x_item = x[i,idx,:].unsqueeze(0)
-                attn_item = attn[i,idx,:].unsqueeze(0)
-                x_item = x_item + attn_item
-                x_item = F.layer_norm(
-                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-                )
-                x_item = x_item + self.mlp.forward(x_item)
-                x_item = F.layer_norm(
-                    x_item,
-                    [self.hidden_dim],
-                    self.norm_w2,
-                    self.norm_b2,
-                    self.norm_eps2,
-                )
-                x[i,idx,:] = x_item.squeeze(0)
-            x = self.to_mask(x, padding_mask)
-        else:
-            x = x + attn
-            x = F.layer_norm(
-                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-            )
-            x = x + self.mlp.forward(x)
-            x = F.layer_norm(
-                x,
-                [self.hidden_dim],
-                self.norm_w2,
-                self.norm_b2,
-                self.norm_eps2,
-            )
+        x = x + attn
+        x = F.layer_norm(
+            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
         return x, k_cache, v_cache
 
     def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
@@ -206,8 +182,7 @@ class T2SBlock:
         else:
             attn = scaled_dot_product_attention(q, k, v, attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(attn, self.out_w, self.out_b)
 
         x = x + attn
@@ -662,7 +637,7 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
         xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         xy_attn_mask = xy_attn_mask.bool()
-        xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
+        xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
 
         ###### decode #####
         y_list = [None]*y.shape[0]
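
Both attention hunks above replace the old permute/reshape/view/transpose round-trip with a single transpose + reshape. Below is a minimal standalone sketch (not part of the patch; the test sizes are arbitrary, and the shapes are assumptions based on the surrounding code: scaled_dot_product_attention returns (batch, num_heads, q_len, head_dim) and hidden_dim = num_heads * head_dim) checking that the two forms produce identical results:

import torch

batch_size, num_heads, q_len, head_dim = 2, 4, 5, 8  # arbitrary test sizes
hidden_dim = num_heads * head_dim
attn = torch.randn(batch_size, num_heads, q_len, head_dim)  # assumed SDPA output layout

# Old path: permute to (q_len, batch, heads, head_dim), flatten, view back,
# then swap the first two axes to reach (batch, q_len, hidden_dim).
old = (
    attn.permute(2, 0, 1, 3)
    .reshape(batch_size * q_len, hidden_dim)
    .view(q_len, batch_size, hidden_dim)
    .transpose(1, 0)
)

# New path: one transpose to (batch, q_len, heads, head_dim), one reshape.
new = attn.transpose(1, 2).reshape(batch_size, q_len, -1)

assert torch.equal(old, new)  # identical values

Both forms still materialize one copy at the reshape, so the change is mainly a readability win: the intermediate (batch_size*q_len, hidden_dim) flatten disappears and the target layout is stated directly.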