From c8c5d6a4cc22d133990eb66b5d2883690bdcba96 Mon Sep 17 00:00:00 2001 From: wangzeyuan Date: Wed, 12 Feb 2025 12:31:06 +0800 Subject: [PATCH] remove unnecessary mask of batch infer --- GPT_SoVITS/AR/models/t2s_model.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 8c8ea1a..916f447 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -629,11 +629,6 @@ class Text2SemanticDecoder(nn.Module): xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device) _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1) - - for i in range(bsz): - l = x_lens[i] - _xy_padding_mask[i,l:max_len,:]=True - xy_attn_mask = xy_mask.logical_or(_xy_padding_mask) xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1) xy_attn_mask = xy_attn_mask.bool() @@ -645,7 +640,7 @@ class Text2SemanticDecoder(nn.Module): idx_list = [None]*y.shape[0] for idx in tqdm(range(1500)): if idx == 0: - xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False) + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None, False) else: xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False) logits = self.ar_predict_layer(