mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
feat:onnx friendly loop with same function
This commit is contained in:
parent
77794a5923
commit
e8fdf472c0
@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module):
|
|||||||
return out_indices
|
return out_indices
|
||||||
|
|
||||||
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
||||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
# ONNX-friendly approach: use unbind instead of enumerate loop
|
||||||
for i, indices in enumerate(q_indices):
|
indices_list = torch.unbind(q_indices, dim=0)
|
||||||
|
quantized_list = []
|
||||||
|
|
||||||
|
for i, indices in enumerate(indices_list):
|
||||||
|
if st + i < len(self.layers):
|
||||||
layer = self.layers[st + i]
|
layer = self.layers[st + i]
|
||||||
quantized = layer.decode(indices)
|
quantized = layer.decode(indices)
|
||||||
quantized_out = quantized_out + quantized
|
quantized_list.append(quantized)
|
||||||
|
|
||||||
|
# Stack and sum instead of iterative addition
|
||||||
|
quantized_out = torch.stack(quantized_list, dim=0).sum(dim=0)
|
||||||
|
|
||||||
return quantized_out
|
return quantized_out
|
||||||
|
Loading…
x
Reference in New Issue
Block a user