mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
optimize attention calc logic (#2010)
Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
This commit is contained in:
parent
ca3cc4997a
commit
282ae1d9b2
@ -145,45 +145,21 @@ class T2SBlock:
|
||||
else:
|
||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
if padding_mask is not None:
|
||||
for i in range(batch_size):
|
||||
# mask = padding_mask[i,:,0]
|
||||
if self.false.device!= padding_mask.device:
|
||||
self.false = self.false.to(padding_mask.device)
|
||||
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
|
||||
x_item = x[i,idx,:].unsqueeze(0)
|
||||
attn_item = attn[i,idx,:].unsqueeze(0)
|
||||
x_item = x_item + attn_item
|
||||
x_item = F.layer_norm(
|
||||
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x_item = x_item + self.mlp.forward(x_item)
|
||||
x_item = F.layer_norm(
|
||||
x_item,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x[i,idx,:] = x_item.squeeze(0)
|
||||
x = self.to_mask(x, padding_mask)
|
||||
else:
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
||||
@ -206,8 +182,7 @@ class T2SBlock:
|
||||
else:
|
||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||
attn = F.linear(attn, self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
@ -662,7 +637,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
|
||||
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
|
||||
xy_attn_mask = xy_attn_mask.bool()
|
||||
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
|
||||
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user