Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-29 22:10:21 +08:00
optimize attention calc logic (#2010)
Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
Commit 282ae1d9b2 (parent ca3cc4997a)
@@ -145,45 +145,21 @@ class T2SBlock:
         else:
             attn = scaled_dot_product_attention(q, k, v, attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
 
-        if padding_mask is not None:
-            for i in range(batch_size):
-                # mask = padding_mask[i,:,0]
-                if self.false.device!= padding_mask.device:
-                    self.false = self.false.to(padding_mask.device)
-                idx = torch.where(padding_mask[i,:,0]==self.false)[0]
-                x_item = x[i,idx,:].unsqueeze(0)
-                attn_item = attn[i,idx,:].unsqueeze(0)
-                x_item = x_item + attn_item
-                x_item = F.layer_norm(
-                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-                )
-                x_item = x_item + self.mlp.forward(x_item)
-                x_item = F.layer_norm(
-                    x_item,
-                    [self.hidden_dim],
-                    self.norm_w2,
-                    self.norm_b2,
-                    self.norm_eps2,
-                )
-                x[i,idx,:] = x_item.squeeze(0)
-            x = self.to_mask(x, padding_mask)
-        else:
-            x = x + attn
-            x = F.layer_norm(
-                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-            )
-            x = x + self.mlp.forward(x)
-            x = F.layer_norm(
-                x,
-                [self.hidden_dim],
-                self.norm_w2,
-                self.norm_b2,
-                self.norm_eps2,
-            )
+        x = x + attn
+        x = F.layer_norm(
+            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
         return x, k_cache, v_cache
 
     def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
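Note on the first hunk: it replaces the permute/reshape/view/transpose chain with a single transpose + reshape, and drops the per-sample padding loop in favour of the plain residual + layer-norm path. Below is a standalone sanity check (not part of the commit; shapes and names are illustrative) showing that the two layouts agree, and that F.layer_norm over [hidden_dim] acts on each position independently, which is presumably why masking afterwards is sufficient:

import torch
import torch.nn.functional as F

batch_size, num_heads, q_len, head_dim = 2, 4, 5, 8
hidden_dim = num_heads * head_dim
attn = torch.randn(batch_size, num_heads, q_len, head_dim)  # SDPA output layout

# Old path: move q_len to the front, flatten, then swap batch and q_len back.
old = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, hidden_dim)
old = old.view(q_len, batch_size, hidden_dim).transpose(1, 0)

# New path: one transpose to (batch, q_len, heads, head_dim) and one reshape.
new = attn.transpose(1, 2).reshape(batch_size, q_len, -1)

assert torch.equal(old, new)  # identical (batch, q_len, hidden_dim) tensor

# Layer norm over [hidden_dim] normalizes each (batch, position) vector on its
# own, so padded positions cannot leak into real tokens.
x = torch.randn(batch_size, q_len, hidden_dim)
whole = F.layer_norm(x, [hidden_dim])
per_pos = torch.stack([F.layer_norm(x[:, t], [hidden_dim]) for t in range(q_len)], dim=1)
assert torch.allclose(whole, per_pos, atol=1e-6)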
@@ -206,8 +182,7 @@ class T2SBlock:
         else:
             attn = scaled_dot_product_attention(q, k, v, attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(attn, self.out_w, self.out_b)
 
         x = x + attn
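The second hunk applies the same layout change inside decode_next_token; on this path the output projection is applied to attn directly, without to_mask. A minimal sketch of the new shape flow, assuming a single decoded position per step and illustrative stand-in tensors for self.out_w / self.out_b:

import torch
import torch.nn.functional as F

batch_size, num_heads, head_dim = 2, 4, 8
hidden_dim = num_heads * head_dim

attn = torch.randn(batch_size, num_heads, 1, head_dim)  # one new position per step (assumed)
out_w = torch.randn(hidden_dim, hidden_dim)             # stand-in output projection weight
out_b = torch.randn(hidden_dim)                         # stand-in output projection bias

attn = attn.transpose(1, 2).reshape(batch_size, 1, -1)  # (batch, 1, hidden_dim)
attn = F.linear(attn, out_w, out_b)                     # projection, no padding mask here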
@@ -662,7 +637,7 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
         xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         xy_attn_mask = xy_attn_mask.bool()
-        xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
+        xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
 
         ###### decode #####
         y_list = [None]*y.shape[0]
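The last hunk keeps xy_padding_mask at shape (bsz, src_len, 1) instead of expanding it to the model dimension. Wherever the mask is later combined with a (bsz, src_len, model_dim) tensor, broadcasting over the last axis should give the same result without materialising the expanded copy. A minimal check under that assumption (names and the masked_fill usage are illustrative, not the repository's own masking helper):

import torch

bsz, src_len, model_dim = 2, 7, 16
x = torch.randn(bsz, src_len, model_dim)
pad = torch.zeros(bsz, src_len, dtype=torch.bool)
pad[:, -2:] = True                                  # pretend the last two positions are padding

narrow = pad.view(bsz, src_len, 1)                  # new: keep the singleton feature dim
wide = narrow.expand(-1, -1, model_dim)             # old: per-feature copy of the mask

assert torch.equal(x.masked_fill(narrow, 0), x.masked_fill(wide, 0))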