diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 878af64e..c47d4ce4 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -394,6 +394,7 @@ class Text2SemanticDecoder(nn.Module): (0, y_len), value=True, ) + x_attn_mask[:, x_len]=False y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), @@ -463,7 +464,7 @@ class Text2SemanticDecoder(nn.Module): targets.append(target.unsqueeze(0)) x_attn_mask = torch.zeros((x_len+(y_len-y_lens[i]), x_len+y_len), dtype=torch.bool).to(device) - x_attn_mask[:, -y_lens[i]:] = True + x_attn_mask[:, -y_lens[i]+1:] = True y_attn_mask = F.pad( torch.triu( torch.ones(y_lens[i], y_lens[i], dtype=torch.bool).to(device), @@ -650,7 +651,7 @@ class Text2SemanticDecoder(nn.Module): y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool).to(device), - diagonal=1, + diagonal=0, ), (x_len, 0), value=False, @@ -1014,7 +1015,7 @@ class Text2SemanticDecoder(nn.Module): value=True, ) y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0), (x_len, 0), value=False, )