run time working

This commit is contained in:
zpeng11 2025-08-25 17:07:38 -04:00
parent 26228402e3
commit d413a4f5b1

View File

@ -128,7 +128,7 @@ class T2SFirstStageDecoder(nn.Module):
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
def forward(self, x, prompt, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
if top_k is None:
top_k = torch.LongTensor([15]).to(device=x.device)
if top_p is None:
@ -137,6 +137,7 @@ class T2SFirstStageDecoder(nn.Module):
repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
if temperature is None:
temperature = torch.FloatTensor([1.0]).to(device=x.device)
minus_one = torch.tensor([-1]).to(x.device).to(torch.int64)
y = prompt
x_example = x[:, :, 0] * 0.0
@ -145,45 +146,59 @@ class T2SFirstStageDecoder(nn.Module):
"all_stage": self.num_layers,
"k": None,
"v": None,
"y_emb": None,
"y_emb": y_emb,
"first_infer": first_infer,
"stage": 0,
}
y_emb = self.ar_audio_embedding(y)
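# Decide at run time whether to embed the whole y (first_infer == 1) or only its last token (first_infer == 0), so that both the first call and subsequent calls are handled correctly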
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:]
print("y_emb shape:", y_emb.shape)
y_emb = torch.cat(
[
cache["y_emb"],
self.ar_audio_embedding(y_to_emb),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
# Decide at run time whether self-attention runs over the whole xy_pos (first_infer == 1) or only its last position (first_infer == 0)
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:]
# Build the attention mask for the concatenated xy sequence
x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(
y_example.transpose(0, 1),
torch.ones(
(y_seq_len, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
print("first iter xy_attn_mask shape:", xy_attn_mask.shape)
cache["k"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
cache["v"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
print("first iter cache k shape:", cache["k"].shape, 'cache v shape:', cache["v"].shape)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
@ -216,7 +231,7 @@ class T2SStageDecoder(nn.Module):
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
def forward(self, x, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
if top_k is None:
top_k = torch.LongTensor([15]).to(device=y.device)
if top_p is None:
@ -225,6 +240,7 @@ class T2SStageDecoder(nn.Module):
repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
if temperature is None:
temperature = torch.FloatTensor([1.0]).to(device=y.device)
minus_one = torch.tensor([-1]).to(y.device).to(torch.int64)
cache = {
"all_stage": self.num_layers,
@ -235,22 +251,32 @@ class T2SStageDecoder(nn.Module):
"stage": 0,
}
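# Decide at run time whether to embed the whole y (first_infer == 1) or only its last token (first_infer == 0), so that both the first call and subsequent calls are handled correctly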
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:]
# Embed the selected part of the y input and append it to the cached y_emb
y_emb = torch.cat(
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
self.ar_audio_embedding(y_to_emb),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
# Concatenate with the x input in preparation for attention
xy_pos = torch.concat([x, y_pos], dim=1)
xy_pos = y_pos[:, -1:]
# Decide at run time whether self-attention runs over the whole xy_pos (first_infer == 1) or only its last position (first_infer == 0)
multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:]
y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
x_example_len = torch.onnx.operators.shape_as_tensor(x_example)[1]
y_example_len = torch.onnx.operators.shape_as_tensor(y_pos)[1]
xy_attn_mask = torch.zeros((1, x_example_len + y_example_len), dtype=torch.bool)
print('xy_attn_mask shape:', xy_attn_mask.shape)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
@ -336,11 +362,14 @@ class Text2SemanticDecoder(nn.Module):
prefix_len = prompts.shape[1]
x = self.onnx_encoder(x, bert_feature)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k, first_infer=torch.LongTensor([1]))
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, torch.empty((1,0,512)).to(torch.float), top_k=top_k,
first_infer=torch.LongTensor([1]),
x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])
stop = False
for idx in tqdm(range(1, 1500)):
enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, first_infer=torch.LongTensor([0]))
enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, x_example, top_k=top_k,
first_infer=torch.LongTensor([0]))
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True