feat: ONNX-friendly decode loop with the same functionality

This commit is contained in:
zpeng11 2025-08-21 22:23:50 -04:00
parent 77794a5923
commit e8fdf472c0

View File

@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module):
return out_indices
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
    """Sum the decoded outputs of successive layers for stacked indices.

    ``q_indices`` is split along dim 0; slice ``i`` is passed to
    ``self.layers[st + i].decode``. Slices whose target position falls
    past the end of ``self.layers`` are ignored.

    Args:
        q_indices: Index tensor whose dim 0 enumerates the layers to decode.
            (The layout of the remaining dims is whatever ``layer.decode``
            expects -- not visible from this method.)
        st: Offset of the first layer in ``self.layers`` to use.

    Returns:
        The element-wise sum of every decoded slice. When no slice maps to
        an existing layer, a zero scalar tensor on ``q_indices.device``.
    """
    # Split into a fixed tuple of per-layer slices up front so the loop runs
    # over plain Python objects instead of iterating the tensor itself
    # (the export-friendly form this method was rewritten for).
    per_layer_indices = torch.unbind(q_indices, dim=0)
    layer_count = len(self.layers)
    decoded = [
        self.layers[st + i].decode(indices)
        for i, indices in enumerate(per_layer_indices)
        if st + i < layer_count
    ]
    if not decoded:
        # torch.stack rejects an empty list; fall back to the neutral element.
        return torch.tensor(0.0, device=q_indices.device)
    # One stack + reduction instead of a chain of in-loop additions.
    return torch.stack(decoded, dim=0).sum(dim=0)