feat:remove prints

This commit is contained in:
zpeng11 2025-08-26 17:09:36 -04:00
parent 3e63595f0e
commit 337da7454e

View File

@ -196,7 +196,6 @@ class T2SModel(nn.Module):
empty_tensor,
top_k, top_p, repetition_penalty, temperature,
torch.LongTensor([1]), x_seq_len, y_seq_len)
print(y.shape, k.shape, v.shape, y_emb.shape, logits.shape, samples.shape)
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
y_seq_len = torch.Tensor([y.shape[1]]).to(torch.int64)