diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index 8b2fad11..be1f33ce 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -51,14 +51,13 @@ def resample_audio(audio: torch.Tensor, orig_sr: int, target_sr: int) -> torch.T audio = audio.unsqueeze(0) # audio shape: (batch, channels, samples) batch, channels, samples = audio.shape - new_samples = int(samples * target_sr / orig_sr) - audio = audio.view(batch * channels, 1, samples) - resampled = F.interpolate(audio, size=new_samples, mode='linear', align_corners=False) - resampled = resampled.view(batch, channels, new_samples) - if resampled.shape[0] == 1 and resampled.shape[1] == 1: - resampled = resampled.squeeze(0).squeeze(0) - elif resampled.shape[0] == 1: - resampled = resampled.squeeze(0) + # Reshape to combine batch and channels for interpolation + audio = audio.reshape(batch * channels, 1, samples) + # Use scale_factor instead of a computed size for ONNX export compatibility + resampled = F.interpolate(audio, scale_factor=target_sr / orig_sr, mode='linear', align_corners=False) + new_samples = resampled.shape[-1] + resampled = resampled.reshape(batch, channels, new_samples) + resampled = resampled.squeeze(0).squeeze(0) return resampled diff --git a/playground/freerun.py b/playground/freerun.py index 225395c6..66877ff3 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -7,7 +7,7 @@ import torch from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx -MODEL_PATH = "playground/v2pro_export/v2pro" +MODEL_PATH = "onnx/v2pro_export/v2pro" def audio_postprocess( audios,