feat: updated fsdecode and decoder interface

zpeng11 2025-08-23 17:35:21 -04:00
parent b45cbc3561
commit 5982080939


@@ -61,7 +61,7 @@ def logits_to_probs(
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-    logits = logits / torch.max(temperature, torch.tensor(1e-5, device=temperature.device, dtype=temperature.dtype))
+    logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))
     # if top_k is not None: # To be captured by onnx
     v, _ = torch.topk(logits, top_k)
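
The changed line does two things: it clamps the divisor at 1e-5 so a temperature of zero cannot blow the logits up to inf, and it now builds that floor tensor on the logits' device with a fixed float dtype rather than inheriting both from temperature (presumably because temperature now arrives as a graph input whose device and dtype are not guaranteed to match the logits). A minimal sketch of the clamped division in isolation, assuming a tensor-valued temperature as in this commit (scale_logits is an illustrative name, not from the diff):

import torch

def scale_logits(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Floor the divisor at 1e-5 so temperature == 0.0 yields large but
    # finite logits instead of inf; the constant lives on the logits'
    # device with a fixed float dtype.
    floor = torch.tensor(1e-5, device=logits.device, dtype=torch.float)
    return logits / torch.max(temperature, floor)

print(scale_logits(torch.randn(4), torch.tensor(0.0)).isfinite().all())  # tensor(True)
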
@@ -115,7 +115,6 @@ class T2SFirstStageDecoder(nn.Module):
         ar_predict_layer,
         loss_fct,
         ar_accuracy_metric,
-        top_k,
         early_stop_num,
         num_layers,
     ):
@@ -126,7 +125,6 @@
         self.ar_predict_layer = ar_predict_layer
         self.loss_fct = loss_fct
         self.ar_accuracy_metric = ar_accuracy_metric
-        self.top_k = top_k
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
@@ -205,7 +203,6 @@ class T2SStageDecoder(nn.Module):
         ar_predict_layer,
         loss_fct,
         ar_accuracy_metric,
-        top_k,
         early_stop_num,
         num_layers,
     ):
@@ -216,7 +213,6 @@
         self.ar_predict_layer = ar_predict_layer
         self.loss_fct = loss_fct
         self.ar_accuracy_metric = ar_accuracy_metric
-        self.top_k = top_k
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
@@ -317,7 +313,6 @@ class Text2SemanticDecoder(nn.Module):
             self.ar_predict_layer,
             self.loss_fct,
             self.ar_accuracy_metric,
-            self.top_k,
             self.early_stop_num,
             self.num_layers,
         )
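
The removals above repeat one mechanical change across T2SFirstStageDecoder, T2SStageDecoder, and both places where Text2SemanticDecoder constructs them: top_k is no longer a constructor argument stored on the module, so it can no longer be frozen into a traced graph at build time. A minimal sketch of the before/after pattern with placeholder internals (StageDecoder and predict_layer are illustrative, not the real classes):

import torch
from torch import nn

class StageDecoder(nn.Module):
    # Before this commit: __init__(..., top_k, ...) kept self.top_k = top_k,
    # fixing the sampling width when the module was built.
    def __init__(self, predict_layer: nn.Module):
        super().__init__()
        self.predict_layer = predict_layer

    # After: top_k is an ordinary forward() argument, decided per call.
    def forward(self, hidden: torch.Tensor, top_k: int):
        logits = self.predict_layer(hidden)
        v, _ = torch.topk(logits, top_k)
        return v

dec = StageDecoder(nn.Linear(16, 32))
print(dec(torch.randn(1, 16), top_k=5).shape)  # torch.Size([1, 5])
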
@@ -328,23 +323,24 @@ class Text2SemanticDecoder(nn.Module):
             self.ar_predict_layer,
             self.loss_fct,
             self.ar_accuracy_metric,
-            self.top_k,
             self.early_stop_num,
             self.num_layers,
         )

-    def forward(self, x, prompts, bert_feature):
+    def forward(self, x, prompts, bert_feature, top_k = None):
         # torch.manual_seed(42)
         # torch.use_deterministic_algorithms(True)
+        if top_k is None:
+            top_k = self.top_k
         early_stop_num = self.early_stop_num
         prefix_len = prompts.shape[1]

         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k)

         stop = False
         for idx in tqdm(range(1, 1500)):
-            enco = self.stage_decoder(y, k, v, y_emb, x_example)
+            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k)
             y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
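
With these edits, forward keeps its old behavior when top_k is omitted (it falls back to self.top_k) but lets callers override the sampling width per invocation, threading the value into both stage decoders. A short usage sketch; model, x, prompts, and bert_feature stand in for objects constructed elsewhere in the pipeline:

# Hypothetical driver code, not part of the diff.
y, idx = model.forward(x, prompts, bert_feature, top_k=15)  # explicit per-call value
y, idx = model.forward(x, prompts, bert_feature)            # falls back to model.top_k
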
@@ -356,9 +352,10 @@ class Text2SemanticDecoder(nn.Module):
         # torch.use_deterministic_algorithms(False)
         return y, idx

-    def infer(self, x, prompts, bert_feature):
+    def infer(self, x, prompts, bert_feature, top_k=None):
         # torch.manual_seed(42)
         # torch.use_deterministic_algorithms(True)
+        if top_k is None:
+            top_k = self.top_k
         early_stop_num = self.early_stop_num
@@ -407,7 +404,7 @@ class Text2SemanticDecoder(nn.Module):
             xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
             xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
+            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
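
Passing temperature=torch.Tensor([1.0]) instead of the float 1.0 matches the new tensor-valued division in logits_to_probs above and keeps the value tensor-typed for export: a Python scalar is folded into a traced graph as a constant, while a tensor argument can remain a runtime input. A minimal sketch of that distinction with an illustrative toy module (Scale and scale.onnx are not from the diff):

import torch

class Scale(torch.nn.Module):
    def forward(self, logits, temperature):
        # Same clamped division as in logits_to_probs; because temperature
        # is a tensor argument, the traced ONNX graph keeps it as an input
        # rather than baking 1.0 in as a constant.
        return logits / torch.max(temperature, torch.tensor(1e-5, dtype=torch.float))

torch.onnx.export(
    Scale(),
    (torch.randn(1, 8), torch.Tensor([1.0])),
    "scale.onnx",
    input_names=["logits", "temperature"],
)
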