diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index f8f6582..5b76c50 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -629,11 +629,6 @@ class Text2SemanticDecoder(nn.Module): xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device) _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1) - - for i in range(bsz): - l = x_lens[i] - _xy_padding_mask[i,l:max_len,:]=True - xy_attn_mask = xy_mask.logical_or(_xy_padding_mask) xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1) xy_attn_mask = xy_attn_mask.bool() @@ -645,7 +640,7 @@ class Text2SemanticDecoder(nn.Module): idx_list = [None]*y.shape[0] for idx in tqdm(range(1500)): if idx == 0: - xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False) + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None, False) else: xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False) logits = self.ar_predict_layer(