failed , testing expand y

This commit is contained in:
zpeng11 2025-08-25 21:57:36 -04:00
parent c85ee3d521
commit 419909b443
2 changed files with 18 additions and 10 deletions

View File

@ -145,10 +145,12 @@ class T2SStageDecoder(nn.Module):
"y_emb": y_emb,
"first_infer": first_infer,
"stage": 0,
"x_seq_len": x_seq_len,
"y_seq_len": y_seq_len,
}
# 运行时判断对最后一个y还是整个y做embedding以正确应对首次和后续
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
multipled = minus_one * first_infer * y_seq_len
index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:]
# 对y输入进行embedding
@ -165,7 +167,7 @@ class T2SStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
# 运行时判断对最后一个xy_pos还是整个xy_pos做self attention
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
multipled = minus_one * first_infer * (x_seq_len + y_seq_len) # xy_pos = 1 or x_seq_len + y_seq_len
index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:]
@ -189,7 +191,7 @@ class T2SStageDecoder(nn.Module):
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
# 运行时判断attension mask使用最后一个还是整个
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_attn_mask)[0]
multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
index_offset = torch.min(minus_one, multipled)
xy_attn_mask = xy_attn_mask[index_offset:, :]
@ -268,20 +270,26 @@ class Text2SemanticDecoder(nn.Module):
x = self.onnx_encoder(x, bert_feature)
init_k = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float)
x_seq_len = x.shape[1]
y_seq_len = prompts.shape[1]
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
empty_tensor = torch.empty((1,0,512)).to(torch.float)
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
torch.empty((1,0,512)).to(torch.float), top_k=top_k,
empty_tensor, top_k=top_k,
first_infer=torch.LongTensor([1]),
x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])
x_seq_len=x_seq_len, y_seq_len=y_seq_len)
stop = False
for idx in tqdm(range(1, 1500)):
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, top_k=top_k,
first_infer=torch.LongTensor([0]), x_seq_len=x.shape[1], y_seq_len=y.shape[1])
y_seq_len = y.shape[1]
enco = self.stage_decoder(empty_tensor, y, k, v, y_emb, top_k=top_k,
first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True

View File

@ -57,7 +57,7 @@ def multi_head_attention_forward_patched(
cache_v = cache["v"][cache["stage"]]
# Magic to get an index of either -1 or -N according to if first_infer_mask is set
minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
multipled = minus_one * first_infer_mask * torch.onnx.operators.shape_as_tensor(query)[0]
multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len'])
index_offset = torch.min(minus_one, multipled)
cache_k[index_offset :, :, :] = k
cache_v[index_offset :, :, :] = v