feat:clean up playground explore audio preprocess, todo:build free run from pure input data

This commit is contained in:
zpeng11 2025-08-19 01:22:00 -04:00
parent aef9d26580
commit dd156f15aa
9 changed files with 58 additions and 14 deletions

4
.gitignore vendored
View File

@ -193,6 +193,8 @@ cython_debug/
# PyPI configuration file
.pypirc
onnx/
*.onnx
tokenizer.json
tokenizer.json
output.wav

View File

@ -67,7 +67,7 @@ class HubertONNXExporter:
print(f"[Error] ONNX model not found at {self.onnx_path}")
return False
def _load_and_preprocess_audio(self, audio_path, max_length=32000):
def _load_and_preprocess_audio(self, audio_path, max_length=160000):
"""Load and preprocess audio file"""
waveform, sample_rate = torchaudio.load(audio_path)
@ -80,10 +80,18 @@ class HubertONNXExporter:
if waveform.shape[0] > 1:
waveform = waveform[0:1]
# Limit length for testing (2 seconds at 16kHz)
# Limit length for testing (10 seconds at 16kHz)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
# make a zero tensor that has length 3200*0.3
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
# concate zero_tensor with waveform
waveform = torch.cat([waveform, zero_tensor], dim=1)
return waveform
def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):

View File

@ -29,17 +29,54 @@ def audio_postprocess(
return audio
def load_and_preprocess_audio(audio_path, max_length=160000):
"""Load and preprocess audio file"""
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 16kHz if needed
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
# Take first channel
if waveform.shape[0] > 1:
waveform = waveform[0:1]
# Limit length for testing (10 seconds at 16kHz)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
input_phones = np.load("playground/input_phones.npy")
input_bert = np.load("playground/input_bert.npy").T.astype(np.float32)
ref_phones = np.load("playground/ref_phones.npy")
ref_bert = np.load("playground/ref_bert.npy").T.astype(np.float32)
audio_prompt_hubert = np.load("playground/audio_prompt_hubert.npy").astype(np.float32)
# make a zero tensor that has length 3200*0.3
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
# concate zero_tensor with waveform
waveform = torch.cat([waveform, zero_tensor], dim=1)
return waveform
def get_audio_hubert(audio_path):
"""Get HuBERT features for the audio file"""
waveform = load_and_preprocess_audio(audio_path)
ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
# transpose axis 1 and 2 with numpy
hubert_feature = hubert_feature.transpose(0, 2, 1)
print("Hubert feature shape:", hubert_feature.shape)
return hubert_feature
input_phones = np.load("playground/ref/input_phones.npy")
input_bert = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
ref_phones = np.load("playground/ref/ref_phones.npy")
ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
outputs = encoder.run(None, {
[x, prompts] = encoder.run(None, {
"text_seq": input_phones,
"text_bert": input_bert,
"ref_seq": ref_phones,
@ -47,10 +84,7 @@ outputs = encoder.run(None, {
"ssl_content": audio_prompt_hubert
})
print(outputs[0].shape, outputs[1].shape)
x = outputs[0]
prompts = outputs[1]
print(x.shape, prompts.shape)
fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")

Binary file not shown.