feat: allow fsdec and sdec to have sampling params

Author: zpeng11
Date:   2025-08-23 12:17:04 -04:00
Parent: 3ccd1c0ea3
Commit: 9ed42daa88


@@ -28,45 +28,45 @@ def logits_to_probs(
     logits,
     previous_tokens=None,
     temperature: float = 1.0,
-    top_k=None,
-    top_p=None,
+    top_k=15,
+    top_p=1.0,
     repetition_penalty: float = 1.0,
 ):
     previous_tokens = previous_tokens.squeeze()
-    if previous_tokens is not None and repetition_penalty != 1.0:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
-        score = torch.where(
-            score < 0,
-            score * repetition_penalty,
-            score / repetition_penalty,
-        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+    # if previous_tokens is not None and repetition_penalty != 1.0:  # Always captured by onnx
+    previous_tokens = previous_tokens.long()
+    score = torch.gather(logits, dim=0, index=previous_tokens)
+    score = torch.where(
+        score < 0,
+        score * repetition_penalty,
+        score / repetition_penalty,
+    )
+    logits.scatter_(dim=0, index=previous_tokens, src=score)

-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(
-                sorted_logits,
-                dim=-1,
-            ),
-            dim=-1,
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0,
-            index=sorted_indices,
-            src=sorted_indices_to_remove,
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+    # if top_p is not None and top_p < 1.0:  # To be captured by onnx
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(
+        torch.nn.functional.softmax(
+            sorted_logits,
+            dim=-1,
+        ),
+        dim=-1,
+    )
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0,
+        index=sorted_indices,
+        src=sorted_indices_to_remove,
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

     logits = logits / max(temperature, 1e-5)

-    if top_k is not None:
-        v, _ = torch.topk(logits, top_k)
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, inf_tensor_value, logits)
+    # if top_k is not None:  # To be captured by onnx
+    v, _ = torch.topk(logits, top_k)
+    pivot = v.select(-1, -1).unsqueeze(-1)
+    logits = torch.where(logits < pivot, inf_tensor_value, logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
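
With the branches commented out, every filter executes unconditionally, so an ONNX trace records repetition penalty, top-p, and top-k as graph ops instead of freezing a single Python path. A minimal standalone sketch of the same pipeline on toy logits (`inf_tensor_value` here is a stand-in for the module-level constant this file defines elsewhere):

```python
import torch

# Stand-in for the module-level -inf constant used by logits_to_probs.
inf_tensor_value = torch.FloatTensor([-float("Inf")])

logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
previous_tokens = torch.tensor([0])  # token 0 was already emitted
top_k, top_p, repetition_penalty, temperature = 3, 0.9, 1.35, 1.0

# 1) Repetition penalty: shrink the scores of already-emitted tokens.
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(dim=0, index=previous_tokens, src=score)

# 2) Top-p: drop the tail of the sorted cumulative distribution.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False  # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
    dim=0, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))

# 3) Temperature, then top-k: mask everything below the k-th largest logit.
logits = logits / max(temperature, 1e-5)
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
print(probs)  # masked entries end up with probability 0
```

The trade-off of dropping the guards is that repetition penalty and top-p now always run; disabling them means passing the neutral values (repetition_penalty=1.0, top_p=1.0) rather than skipping a branch.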
@@ -130,7 +130,16 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

-    def forward(self, x, prompt):
+    def forward(self, x, prompt, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=x.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=x.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=x.device)
         y = prompt
         x_example = x[:, :, 0] * 0.0
         # N, 1, 512
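
The fallbacks are built as one-element tensors on the input's device rather than plain Python numbers, presumably so the sampling knobs stay live tensor inputs that an ONNX export can expose; a scalar would be folded into the traced graph as a constant. A hedged call-site sketch (the `fsdec` handle is an assumption, not part of this diff):

```python
# Hypothetical call: each knob travels as a one-element tensor so it can be
# overridden per request after export instead of being baked in at trace time.
outputs = fsdec(
    x,
    prompt,
    top_k=torch.LongTensor([20]),
    top_p=torch.FloatTensor([0.8]),
    repetition_penalty=torch.FloatTensor([1.35]),
    temperature=torch.FloatTensor([0.7]),
)
```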
@@ -180,7 +189,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)

         y = torch.concat([y, samples], dim=1)
@@ -211,7 +220,16 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

-    def forward(self, y, k, v, y_emb, x_example):
+    def forward(self, y, k, v, y_emb, x_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=y.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=y.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=y.device)
         cache = {
             "all_stage": self.num_layers,
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
@@ -240,7 +258,7 @@ class T2SStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)

         y = torch.concat([y, samples], dim=1)
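
Since both decoders now accept the four extra tensors, an export script could declare them as named graph inputs. The following is a sketch only, under assumed variable names and an assumed output signature; this diff does not touch the repo's export code:

```python
import torch

# Illustrative ONNX export of the stage decoder with sampling knobs as inputs.
sampling_inputs = (
    torch.LongTensor([15]),     # top_k
    torch.FloatTensor([1.0]),   # top_p
    torch.FloatTensor([1.35]),  # repetition_penalty
    torch.FloatTensor([1.0]),   # temperature
)
torch.onnx.export(
    sdec,  # an assumed T2SStageDecoder instance
    (y, k, v, y_emb, x_example, *sampling_inputs),
    "t2s_sdec.onnx",
    input_names=["iy", "ik", "iv", "iy_emb", "ix_example",
                 "top_k", "top_p", "repetition_penalty", "temperature"],
    output_names=["y", "k", "v", "y_emb", "logits", "samples"],  # assumed
)
```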
@@ -389,7 +407,7 @@ class Text2SemanticDecoder(nn.Module):
             xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
             xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
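
Taken together, a driver can now thread one set of sampling tensors through the first-stage call and every subsequent stage step. A sketch of that loop, assuming the usual fsdec-then-sdec return signatures of this file and an `EOS` id defined by the caller:

```python
# Assumed driver loop: fsdec primes the KV cache, sdec extends it one token at
# a time, and the same sampling tensors are reused at every step.
top_k = torch.LongTensor([15])
top_p = torch.FloatTensor([1.0])
repetition_penalty = torch.FloatTensor([1.35])
temperature = torch.FloatTensor([1.0])

y, k, v, y_emb, x_example = fsdec(x, prompts, top_k, top_p, repetition_penalty, temperature)
for _ in range(1500):  # hard cap on generated tokens
    y, k, v, y_emb, logits, samples = sdec(
        y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature
    )
    if torch.argmax(logits, dim=-1)[0] == EOS or samples[0, 0] == EOS:
        break
```

Note that only the fsdec/sdec path is fully parameterized: in the Text2SemanticDecoder hunk above, top_p, repetition_penalty, and temperature remain pinned at 1.0, 1.35, and 1.0, with only top_k threaded through.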