From 84711e27f9ccb390b3a24f68f66138aa741e0cbb Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Fri, 12 Apr 2024 16:38:55 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dmask=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=9A=84=E5=A4=8D=E8=AF=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 878af64e..c47d4ce4 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -394,6 +394,7 @@ class Text2SemanticDecoder(nn.Module): (0, y_len), value=True, ) + x_attn_mask[:, x_len]=False y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), @@ -463,7 +464,7 @@ class Text2SemanticDecoder(nn.Module): targets.append(target.unsqueeze(0)) x_attn_mask = torch.zeros((x_len+(y_len-y_lens[i]), x_len+y_len), dtype=torch.bool).to(device) - x_attn_mask[:, -y_lens[i]:] = True + x_attn_mask[:, -y_lens[i]+1:] = True y_attn_mask = F.pad( torch.triu( torch.ones(y_lens[i], y_lens[i], dtype=torch.bool).to(device), @@ -650,7 +651,7 @@ class Text2SemanticDecoder(nn.Module): y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool).to(device), - diagonal=1, + diagonal=0, ), (x_len, 0), value=False, @@ -1014,7 +1015,7 @@ class Text2SemanticDecoder(nn.Module): value=True, ) y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0), (x_len, 0), value=False, )