feat: allow fsdec and sdec to have sampling params

zpeng11 2025-08-23 12:17:04 -04:00
parent 3ccd1c0ea3
commit 9ed42daa88


@@ -28,12 +28,12 @@ def logits_to_probs(
     logits,
     previous_tokens=None,
     temperature: float = 1.0,
-    top_k=None,
-    top_p=None,
+    top_k=15,
+    top_p=1.0,
     repetition_penalty: float = 1.0,
 ):
     previous_tokens = previous_tokens.squeeze()
-    if previous_tokens is not None and repetition_penalty != 1.0:
+    # if previous_tokens is not None and repetition_penalty != 1.0: # Always captured by onnx
     previous_tokens = previous_tokens.long()
     score = torch.gather(logits, dim=0, index=previous_tokens)
     score = torch.where(
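For orientation, the repetition-penalty block these context lines belong to can be exercised on its own. A minimal sketch assuming 1-D vocabulary logits and a 1-D tensor of already-emitted token ids (the helper name is illustrative, not from this file):

    import torch

    def apply_repetition_penalty(logits, previous_tokens, penalty=1.35):
        # Gather the current scores of every token that was already generated.
        prev = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=prev)
        # Penalize: negative scores move further down, positive scores toward zero.
        score = torch.where(score < 0, score * penalty, score / penalty)
        # Write the penalized scores back into the logits in place.
        logits.scatter_(dim=0, index=prev, src=score)
        return logits

    logits = apply_repetition_penalty(torch.randn(1025), torch.tensor([3, 17, 42]))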
@@ -43,7 +43,7 @@ def logits_to_probs(
     )
     logits.scatter_(dim=0, index=previous_tokens, src=score)
-    if top_p is not None and top_p < 1.0:
+    # if top_p is not None and top_p < 1.0: # To be captured by onnx
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cum_probs = torch.cumsum(
         torch.nn.functional.softmax(
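The always-on branch above is standard nucleus (top-p) filtering: sort the logits, accumulate their softmax mass, and cut the tail. A standalone sketch under the same 1-D-logits assumption (the shift that protects the most likely token follows the common formulation; the file's exact masking code lies outside this hunk):

    import torch

    def top_p_filter(logits, top_p=0.8):
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark everything after cumulative mass exceeds top_p, shifted by one
        # position so the highest-probability token always survives.
        remove = cum_probs > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        logits[sorted_indices[remove]] = float("-inf")
        return logits

    kept = torch.softmax(top_p_filter(torch.randn(1025)), dim=-1) > 0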
@@ -63,7 +63,7 @@ def logits_to_probs(
     logits = logits / max(temperature, 1e-5)
-    if top_k is not None:
+    # if top_k is not None: # To be captured by onnx
     v, _ = torch.topk(logits, top_k)
     pivot = v.select(-1, -1).unsqueeze(-1)
     logits = torch.where(logits < pivot, inf_tensor_value, logits)
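The top-k branch uses a pivot trick rather than explicit index masking: the k-th largest logit becomes a threshold, and everything strictly below it is pushed to an effective minus infinity (inf_tensor_value in this file; torch.full_like with -inf stands in for it here). A runnable sketch:

    import torch

    def top_k_filter(logits, top_k=15, temperature=1.0):
        # Temperature first, as in the hunk above; 1e-5 guards against division by zero.
        logits = logits / max(temperature, 1e-5)
        v, _ = torch.topk(logits, top_k)
        pivot = v.select(-1, -1).unsqueeze(-1)  # smallest of the top-k values
        return torch.where(logits < pivot, torch.full_like(logits, float("-inf")), logits)

    probs = torch.softmax(top_k_filter(torch.randn(1025)), dim=-1)  # ~15 nonzero entries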
@@ -130,7 +130,16 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

-    def forward(self, x, prompt):
+    def forward(self, x, prompt, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=x.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=x.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=x.device)
         y = prompt
         x_example = x[:, :, 0] * 0.0
         # N, 1, 512
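Note that the defaults are built as 1-element tensors on x's device rather than Python scalars, presumably so an exported ONNX graph can expose them as runtime inputs instead of folding them in as constants. A hedged sketch of how a caller might now drive the first-stage decoder (the wrapper name and chosen values are illustrative, not from this commit):

    import torch

    def run_first_stage(fsdec, x, prompt, top_k=20, top_p=0.8,
                        repetition_penalty=1.35, temperature=0.7):
        # fsdec: a T2SFirstStageDecoder; x: text features; prompt: semantic prompt tokens.
        dev = x.device
        return fsdec(
            x, prompt,
            top_k=torch.LongTensor([top_k]).to(dev),
            top_p=torch.FloatTensor([top_p]).to(dev),
            repetition_penalty=torch.FloatTensor([repetition_penalty]).to(dev),
            temperature=torch.FloatTensor([temperature]).to(dev),
        )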
@@ -180,7 +189,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
         y = torch.concat([y, samples], dim=1)
@@ -211,7 +220,16 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

-    def forward(self, y, k, v, y_emb, x_example):
+    def forward(self, y, k, v, y_emb, x_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=y.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=y.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=y.device)
         cache = {
             "all_stage": self.num_layers,
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
@@ -240,7 +258,7 @@ class T2SStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
         y = torch.concat([y, samples], dim=1)
@@ -389,7 +407,7 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
         if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
             stop = True
         if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
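For context, sample() as called in these hunks conventionally finishes with a multinomial draw over whatever probability mass the filters left; the EOS and early-stop checks above then decide whether the outer decode loop halts. A minimal sketch of that final step (the helper name is illustrative, not the file's sample()):

    import torch

    def draw_token(filtered_logits):
        # Draw a single token id from the filtered distribution; an argmax
        # over probs would be the deterministic (greedy) alternative.
        probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # shape [1]

    token = draw_token(torch.randn(1025))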