修复mask导致的复读

This commit is contained in:
ChasonJiang 2024-04-12 16:38:55 +08:00
parent 2788c756be
commit 84711e27f9

View File

@ -394,6 +394,7 @@ class Text2SemanticDecoder(nn.Module):
(0, y_len),
value=True,
)
x_attn_mask[:, x_len]=False
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@ -463,7 +464,7 @@ class Text2SemanticDecoder(nn.Module):
targets.append(target.unsqueeze(0))
x_attn_mask = torch.zeros((x_len+(y_len-y_lens[i]), x_len+y_len), dtype=torch.bool).to(device)
x_attn_mask[:, -y_lens[i]:] = True
x_attn_mask[:, -y_lens[i]+1:] = True
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_lens[i], y_lens[i], dtype=torch.bool).to(device),
@ -650,7 +651,7 @@ class Text2SemanticDecoder(nn.Module):
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool).to(device),
diagonal=1,
diagonal=0,
),
(x_len, 0),
value=False,
@ -1014,7 +1015,7 @@ class Text2SemanticDecoder(nn.Module):
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),
(x_len, 0),
value=False,
)