Merge c8c5d6a4cc22d133990eb66b5d2883690bdcba96 into 959a2ddbeb36844afbd97aadabbaa17e39e3b616

This commit is contained in:
wzy3650 2025-03-04 14:19:39 +08:00 committed by GitHub
commit 827b84205e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -629,11 +629,6 @@ class Text2SemanticDecoder(nn.Module):
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
for i in range(bsz):
l = x_lens[i]
_xy_padding_mask[i,l:max_len,:]=True
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
xy_attn_mask = xy_attn_mask.bool()
@ -645,7 +640,7 @@ class Text2SemanticDecoder(nn.Module):
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None, False)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
logits = self.ar_predict_layer(