From e8fdf472c0c41a5bc0bf36dddb78c8b801bae090 Mon Sep 17 00:00:00 2001
From: zpeng11
Date: Thu, 21 Aug 2025 22:23:50 -0400
Subject: [PATCH] feat:onnx friendly loop with same function

---
 GPT_SoVITS/module/core_vq.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/GPT_SoVITS/module/core_vq.py b/GPT_SoVITS/module/core_vq.py
index b7dab317..876a6984 100644
--- a/GPT_SoVITS/module/core_vq.py
+++ b/GPT_SoVITS/module/core_vq.py
@@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module):
         return out_indices
 
     def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
-        quantized_out = torch.tensor(0.0, device=q_indices.device)
-        for i, indices in enumerate(q_indices):
-            layer = self.layers[st + i]
-            quantized = layer.decode(indices)
-            quantized_out = quantized_out + quantized
+        # ONNX-friendly approach: use unbind instead of enumerate loop
+        indices_list = torch.unbind(q_indices, dim=0)
+        quantized_list = []
+
+        for i, indices in enumerate(indices_list):
+            if st + i < len(self.layers):
+                layer = self.layers[st + i]
+                quantized = layer.decode(indices)
+                quantized_list.append(quantized)
+
+        # Stack and sum instead of iterative addition
+        quantized_out = torch.stack(quantized_list, dim=0).sum(dim=0)
+
         return quantized_out
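
The patch replaces the iterative accumulation in ResidualVectorQuantization.decode with an unbind-then-stack-and-sum formulation so the loop exports more cleanly to ONNX. The sketch below is a minimal, standalone equivalence check, not part of the patch: ToyLayer, decode_old, and decode_new are hypothetical stand-ins that only mimic the decode(indices) interface of the real quantizer layers in GPT_SoVITS/module/core_vq.py, assuming each layer decodes (B, T) codes to a (B, D, T) tensor.

# Hypothetical equivalence check for the decode() rewrite above.
# ToyLayer is NOT the real VectorQuantization layer; it only imitates
# its decode(indices) -> (B, D, T) shape contract for this test.
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    def __init__(self, codebook_size: int = 32, dim: int = 8):
        super().__init__()
        self.codebook = nn.Embedding(codebook_size, dim)

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        # (B, T) integer codes -> (B, D, T), like a VQ codebook lookup
        return self.codebook(indices).transpose(1, 2)


def decode_old(layers, q_indices, st=0):
    # Original behavior: accumulate layer outputs starting from a 0-d tensor
    quantized_out = torch.tensor(0.0, device=q_indices.device)
    for i, indices in enumerate(q_indices):
        quantized_out = quantized_out + layers[st + i].decode(indices)
    return quantized_out


def decode_new(layers, q_indices, st=0):
    # Patched behavior: unbind over the quantizer axis, then stack and sum
    quantized_list = []
    for i, indices in enumerate(torch.unbind(q_indices, dim=0)):
        if st + i < len(layers):
            quantized_list.append(layers[st + i].decode(indices))
    return torch.stack(quantized_list, dim=0).sum(dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    layers = nn.ModuleList([ToyLayer() for _ in range(4)])
    q_indices = torch.randint(0, 32, (4, 2, 10))  # (n_q, B, T)
    assert torch.allclose(decode_old(layers, q_indices), decode_new(layers, q_indices))
    print("old and new decode() agree")

Under these assumptions both paths sum the same per-layer reconstructions, so the outputs should match to floating-point tolerance; the new form simply avoids the scalar-initialized accumulator and direct iteration over a tensor that can trip up ONNX tracing.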