feat:clean up playground explore audio preprocess, todo:build free run from pure input data

This commit is contained in:
zpeng11 2025-08-19 01:22:00 -04:00
parent aef9d26580
commit dd156f15aa
9 changed files with 58 additions and 14 deletions

2
.gitignore vendored
View File

@ -193,6 +193,8 @@ cython_debug/
# PyPI configuration file # PyPI configuration file
.pypirc .pypirc
onnx/ onnx/
*.onnx *.onnx
tokenizer.json tokenizer.json
output.wav

View File

@ -67,7 +67,7 @@ class HubertONNXExporter:
print(f"[Error] ONNX model not found at {self.onnx_path}") print(f"[Error] ONNX model not found at {self.onnx_path}")
return False return False
def _load_and_preprocess_audio(self, audio_path, max_length=32000): def _load_and_preprocess_audio(self, audio_path, max_length=160000):
"""Load and preprocess audio file""" """Load and preprocess audio file"""
waveform, sample_rate = torchaudio.load(audio_path) waveform, sample_rate = torchaudio.load(audio_path)
@ -80,10 +80,18 @@ class HubertONNXExporter:
if waveform.shape[0] > 1: if waveform.shape[0] > 1:
waveform = waveform[0:1] waveform = waveform[0:1]
# Limit length for testing (2 seconds at 16kHz) # Limit length for testing (10 seconds at 16kHz)
if waveform.shape[1] > max_length: if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length] waveform = waveform[:, :max_length]
# make a zero tensor that has length 3200*0.3
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
# concate zero_tensor with waveform
waveform = torch.cat([waveform, zero_tensor], dim=1)
return waveform return waveform
def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"): def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):

View File

@ -29,17 +29,54 @@ def audio_postprocess(
return audio return audio
def load_and_preprocess_audio(audio_path, max_length=160000):
"""Load and preprocess audio file"""
waveform, sample_rate = torchaudio.load(audio_path)
input_phones = np.load("playground/input_phones.npy") # Resample to 16kHz if needed
input_bert = np.load("playground/input_bert.npy").T.astype(np.float32) if sample_rate != 16000:
ref_phones = np.load("playground/ref_phones.npy") resampler = torchaudio.transforms.Resample(sample_rate, 16000)
ref_bert = np.load("playground/ref_bert.npy").T.astype(np.float32) waveform = resampler(waveform)
audio_prompt_hubert = np.load("playground/audio_prompt_hubert.npy").astype(np.float32)
# Take first channel
if waveform.shape[0] > 1:
waveform = waveform[0:1]
# Limit length for testing (10 seconds at 16kHz)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
# make a zero tensor that has length 3200*0.3
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
# concate zero_tensor with waveform
waveform = torch.cat([waveform, zero_tensor], dim=1)
return waveform
def get_audio_hubert(audio_path):
"""Get HuBERT features for the audio file"""
waveform = load_and_preprocess_audio(audio_path)
ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
# transpose axis 1 and 2 with numpy
hubert_feature = hubert_feature.transpose(0, 2, 1)
print("Hubert feature shape:", hubert_feature.shape)
return hubert_feature
input_phones = np.load("playground/ref/input_phones.npy")
input_bert = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
ref_phones = np.load("playground/ref/ref_phones.npy")
ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx") encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
outputs = encoder.run(None, { [x, prompts] = encoder.run(None, {
"text_seq": input_phones, "text_seq": input_phones,
"text_bert": input_bert, "text_bert": input_bert,
"ref_seq": ref_phones, "ref_seq": ref_phones,
@ -47,10 +84,7 @@ outputs = encoder.run(None, {
"ssl_content": audio_prompt_hubert "ssl_content": audio_prompt_hubert
}) })
print(outputs[0].shape, outputs[1].shape) print(x.shape, prompts.shape)
x = outputs[0]
prompts = outputs[1]
fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx") fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")

Binary file not shown.