diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 95f14859..aef7825c 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -462,7 +462,7 @@ class Text2SemanticDecoder(nn.Module): value=True, ) y_attn_mask = F.pad( - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal必须为0,否则会导致batch_size>1时的复读情况 + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0), (x_len, 0), value=False, )