mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat:onnx friendly loop with same function
This commit is contained in:
parent
77794a5923
commit
e8fdf472c0
@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module):
|
||||
return out_indices
|
||||
|
||||
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
||||
for i, indices in enumerate(q_indices):
|
||||
layer = self.layers[st + i]
|
||||
quantized = layer.decode(indices)
|
||||
quantized_out = quantized_out + quantized
|
||||
# ONNX-friendly approach: use unbind instead of enumerate loop
|
||||
indices_list = torch.unbind(q_indices, dim=0)
|
||||
quantized_list = []
|
||||
|
||||
for i, indices in enumerate(indices_list):
|
||||
if st + i < len(self.layers):
|
||||
layer = self.layers[st + i]
|
||||
quantized = layer.decode(indices)
|
||||
quantized_list.append(quantized)
|
||||
|
||||
# Stack and sum instead of iterative addition
|
||||
quantized_out = torch.stack(quantized_list, dim=0).sum(dim=0)
|
||||
|
||||
return quantized_out
|
||||
|
Loading…
x
Reference in New Issue
Block a user