failed , testing expand y

This commit is contained in:
zpeng11 2025-08-25 21:57:36 -04:00
parent c85ee3d521
commit 419909b443
2 changed files with 18 additions and 10 deletions

View File

@ -145,10 +145,12 @@ class T2SStageDecoder(nn.Module):
"y_emb": y_emb, "y_emb": y_emb,
"first_infer": first_infer, "first_infer": first_infer,
"stage": 0, "stage": 0,
"x_seq_len": x_seq_len,
"y_seq_len": y_seq_len,
} }
# 运行时判断对最后一个y还是整个y做embedding以正确应对首次和后续 # 运行时判断对最后一个y还是整个y做embedding以正确应对首次和后续
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1] multipled = minus_one * first_infer * y_seq_len
index_offset = torch.min(minus_one, multipled) index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:] y_to_emb = y[:, index_offset:]
# 对y输入进行embedding # 对y输入进行embedding
@ -165,7 +167,7 @@ class T2SStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1) xy_pos = torch.concat([x, y_pos], dim=1)
# 运行时判断对最后一个xy_pos还是整个xy_pos做self attention # 运行时判断对最后一个xy_pos还是整个xy_pos做self attention
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1] multipled = minus_one * first_infer * (x_seq_len + y_seq_len) # xy_pos = 1 or x_seq_len + y_seq_len
index_offset = torch.min(minus_one, multipled) index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:] xy_pos = xy_pos[:, index_offset:]
@ -189,7 +191,7 @@ class T2SStageDecoder(nn.Module):
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
# 运行时判断attension mask使用最后一个还是整个 # 运行时判断attension mask使用最后一个还是整个
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_attn_mask)[0] multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
index_offset = torch.min(minus_one, multipled) index_offset = torch.min(minus_one, multipled)
xy_attn_mask = xy_attn_mask[index_offset:, :] xy_attn_mask = xy_attn_mask[index_offset:, :]
@ -268,20 +270,26 @@ class Text2SemanticDecoder(nn.Module):
x = self.onnx_encoder(x, bert_feature) x = self.onnx_encoder(x, bert_feature)
init_k = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float) x_seq_len = x.shape[1]
init_v = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float) y_seq_len = prompts.shape[1]
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
empty_tensor = torch.empty((1,0,512)).to(torch.float)
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v, y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
torch.empty((1,0,512)).to(torch.float), top_k=top_k, empty_tensor, top_k=top_k,
first_infer=torch.LongTensor([1]), first_infer=torch.LongTensor([1]),
x_seq_len=x.shape[1], y_seq_len=prompts.shape[1]) x_seq_len=x_seq_len, y_seq_len=y_seq_len)
stop = False stop = False
for idx in tqdm(range(1, 1500)): for idx in tqdm(range(1, 1500)):
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, top_k=top_k, y_seq_len = y.shape[1]
first_infer=torch.LongTensor([0]), x_seq_len=x.shape[1], y_seq_len=y.shape[1]) enco = self.stage_decoder(empty_tensor, y, k, v, y_emb, top_k=top_k,
first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
y, k, v, y_emb, logits, samples = enco y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True stop = True

View File

@ -57,7 +57,7 @@ def multi_head_attention_forward_patched(
cache_v = cache["v"][cache["stage"]] cache_v = cache["v"][cache["stage"]]
# Magic to get an index of either -1 or -N according to if first_infer_mask is set # Magic to get an index of either -1 or -N according to if first_infer_mask is set
minus_one = torch.tensor([-1]).to(k.device).to(torch.int64) minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
multipled = minus_one * first_infer_mask * torch.onnx.operators.shape_as_tensor(query)[0] multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len'])
index_offset = torch.min(minus_one, multipled) index_offset = torch.min(minus_one, multipled)
cache_k[index_offset :, :, :] = k cache_k[index_offset :, :, :] = k
cache_v[index_offset :, :, :] = v cache_v[index_offset :, :, :] = v