diff --git a/.gitignore b/.gitignore index e9721828..d645f020 100644 --- a/.gitignore +++ b/.gitignore @@ -193,6 +193,8 @@ cython_debug/ # PyPI configuration file .pypirc + onnx/ *.onnx -tokenizer.json \ No newline at end of file +tokenizer.json +output.wav \ No newline at end of file diff --git a/playground/export_hubert.py b/playground/export_hubert.py index 75c9be38..463c6c9b 100644 --- a/playground/export_hubert.py +++ b/playground/export_hubert.py @@ -67,7 +67,7 @@ class HubertONNXExporter: print(f"[Error] ONNX model not found at {self.onnx_path}") return False - def _load_and_preprocess_audio(self, audio_path, max_length=32000): + def _load_and_preprocess_audio(self, audio_path, max_length=160000): """Load and preprocess audio file""" waveform, sample_rate = torchaudio.load(audio_path) @@ -80,10 +80,18 @@ class HubertONNXExporter: if waveform.shape[0] > 1: waveform = waveform[0:1] - # Limit length for testing (2 seconds at 16kHz) + # Limit length for testing (10 seconds at 16kHz) if waveform.shape[1] > max_length: waveform = waveform[:, :max_length] - + + # make a zero tensor of length 9600 (= 32000 * 0.3, i.e. 0.6 s of silence at 16 kHz) + zero_tensor = torch.zeros((1, 9600), dtype=torch.float32) + + print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape) + + # concatenate zero_tensor with waveform + waveform = torch.cat([waveform, zero_tensor], dim=1) + return waveform def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"): diff --git a/playground/freerun.py b/playground/freerun.py index 6adcd8f4..2bad817a 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -29,17 +29,54 @@ def audio_postprocess( return audio +def load_and_preprocess_audio(audio_path, max_length=160000): + """Load and preprocess audio file""" + waveform, sample_rate = torchaudio.load(audio_path) + + # Resample to 16kHz if needed + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + + # Take first channel + if 
waveform.shape[0] > 1: + waveform = waveform[0:1] + + # Limit length for testing (10 seconds at 16kHz) + if waveform.shape[1] > max_length: + waveform = waveform[:, :max_length] -input_phones = np.load("playground/input_phones.npy") -input_bert = np.load("playground/input_bert.npy").T.astype(np.float32) -ref_phones = np.load("playground/ref_phones.npy") -ref_bert = np.load("playground/ref_bert.npy").T.astype(np.float32) -audio_prompt_hubert = np.load("playground/audio_prompt_hubert.npy").astype(np.float32) + # make a zero tensor of length 9600 (= 32000 * 0.3, i.e. 0.6 s of silence at 16 kHz) + zero_tensor = torch.zeros((1, 9600), dtype=torch.float32) + + print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape) + + # concatenate zero_tensor with waveform + waveform = torch.cat([waveform, zero_tensor], dim=1) + + return waveform + +def get_audio_hubert(audio_path): + """Get HuBERT features for the audio file""" + waveform = load_and_preprocess_audio(audio_path) + ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx") + ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()} + hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32) + # transpose axis 1 and 2 with numpy + hubert_feature = hubert_feature.transpose(0, 2, 1) + print("Hubert feature shape:", hubert_feature.shape) + return hubert_feature + +input_phones = np.load("playground/ref/input_phones.npy") +input_bert = np.load("playground/ref/input_bert.npy").T.astype(np.float32) +ref_phones = np.load("playground/ref/ref_phones.npy") +ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32) +audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav") encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx") -outputs = encoder.run(None, { +[x, prompts] = encoder.run(None, { "text_seq": input_phones, "text_bert": input_bert, "ref_seq": ref_phones, @@ -47,10 +84,7 @@ outputs = encoder.run(None, { "ssl_content": audio_prompt_hubert })
-print(outputs[0].shape, outputs[1].shape) - -x = outputs[0] -prompts = outputs[1] +print(x.shape, prompts.shape) fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx") sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") diff --git a/playground/output.wav b/playground/output.wav index 7ff11483..d213ae43 100644 Binary files a/playground/output.wav and b/playground/output.wav differ diff --git a/playground/audio_prompt_hubert.npy b/playground/ref/audio_prompt_hubert.npy similarity index 100% rename from playground/audio_prompt_hubert.npy rename to playground/ref/audio_prompt_hubert.npy diff --git a/playground/input_bert.npy b/playground/ref/input_bert.npy similarity index 100% rename from playground/input_bert.npy rename to playground/ref/input_bert.npy diff --git a/playground/input_phones.npy b/playground/ref/input_phones.npy similarity index 100% rename from playground/input_phones.npy rename to playground/ref/input_phones.npy diff --git a/playground/ref_bert.npy b/playground/ref/ref_bert.npy similarity index 100% rename from playground/ref_bert.npy rename to playground/ref/ref_bert.npy diff --git a/playground/ref_phones.npy b/playground/ref/ref_phones.npy similarity index 100% rename from playground/ref_phones.npy rename to playground/ref/ref_phones.npy