mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-08 16:00:01 +08:00
修复mask导致的复读
This commit is contained in:
parent
2788c756be
commit
84711e27f9
@ -394,6 +394,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
x_attn_mask[:, x_len]=False
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(
|
||||
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
||||
@ -463,7 +464,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
targets.append(target.unsqueeze(0))
|
||||
|
||||
x_attn_mask = torch.zeros((x_len+(y_len-y_lens[i]), x_len+y_len), dtype=torch.bool).to(device)
|
||||
x_attn_mask[:, -y_lens[i]:] = True
|
||||
x_attn_mask[:, -y_lens[i]+1:] = True
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(
|
||||
torch.ones(y_lens[i], y_lens[i], dtype=torch.bool).to(device),
|
||||
@ -650,7 +651,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(
|
||||
torch.ones(y_len, y_len, dtype=torch.bool).to(device),
|
||||
diagonal=1,
|
||||
diagonal=0,
|
||||
),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
@ -1014,7 +1015,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
value=True,
|
||||
)
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user