mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
feat: clean up playground and explore audio preprocessing; todo: build a free run from pure input data
This commit is contained in:
parent
aef9d26580
commit
dd156f15aa
4
.gitignore
vendored
4
.gitignore
vendored
@ -193,6 +193,8 @@ cython_debug/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
onnx/
|
||||
*.onnx
|
||||
tokenizer.json
|
||||
tokenizer.json
|
||||
output.wav
|
@ -67,7 +67,7 @@ class HubertONNXExporter:
|
||||
print(f"[Error] ONNX model not found at {self.onnx_path}")
|
||||
return False
|
||||
|
||||
def _load_and_preprocess_audio(self, audio_path, max_length=32000):
|
||||
def _load_and_preprocess_audio(self, audio_path, max_length=160000):
|
||||
"""Load and preprocess audio file"""
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
|
||||
@ -80,10 +80,18 @@ class HubertONNXExporter:
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0:1]
|
||||
|
||||
# Limit length for testing (2 seconds at 16kHz)
|
||||
# Limit length for testing (10 seconds at 16kHz)
|
||||
if waveform.shape[1] > max_length:
|
||||
waveform = waveform[:, :max_length]
|
||||
|
||||
|
||||
# make a zero tensor of 9600 samples (32000 * 0.3), i.e. 0.6 s of silence at 16 kHz
|
||||
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
|
||||
|
||||
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
|
||||
|
||||
# concatenate zero_tensor onto the end of waveform
|
||||
waveform = torch.cat([waveform, zero_tensor], dim=1)
|
||||
|
||||
return waveform
|
||||
|
||||
def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
|
||||
|
@ -29,17 +29,54 @@ def audio_postprocess(
|
||||
|
||||
return audio
|
||||
|
||||
def load_and_preprocess_audio(audio_path, max_length=160000):
|
||||
"""Load and preprocess audio file"""
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
|
||||
# Resample to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
|
||||
waveform = resampler(waveform)
|
||||
|
||||
# Take first channel
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0:1]
|
||||
|
||||
# Limit length for testing (10 seconds at 16kHz)
|
||||
if waveform.shape[1] > max_length:
|
||||
waveform = waveform[:, :max_length]
|
||||
|
||||
input_phones = np.load("playground/input_phones.npy")
|
||||
input_bert = np.load("playground/input_bert.npy").T.astype(np.float32)
|
||||
ref_phones = np.load("playground/ref_phones.npy")
|
||||
ref_bert = np.load("playground/ref_bert.npy").T.astype(np.float32)
|
||||
audio_prompt_hubert = np.load("playground/audio_prompt_hubert.npy").astype(np.float32)
|
||||
# make a zero tensor of 9600 samples (32000 * 0.3), i.e. 0.6 s of silence at 16 kHz
|
||||
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
|
||||
|
||||
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
|
||||
|
||||
# concatenate zero_tensor onto the end of waveform
|
||||
waveform = torch.cat([waveform, zero_tensor], dim=1)
|
||||
|
||||
return waveform
|
||||
|
||||
def get_audio_hubert(audio_path):
|
||||
"""Get HuBERT features for the audio file"""
|
||||
waveform = load_and_preprocess_audio(audio_path)
|
||||
ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
|
||||
ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
|
||||
hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
|
||||
# swap axes 1 and 2 with numpy
|
||||
hubert_feature = hubert_feature.transpose(0, 2, 1)
|
||||
print("Hubert feature shape:", hubert_feature.shape)
|
||||
return hubert_feature
|
||||
|
||||
input_phones = np.load("playground/ref/input_phones.npy")
|
||||
input_bert = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
|
||||
ref_phones = np.load("playground/ref/ref_phones.npy")
|
||||
ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
|
||||
audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
|
||||
|
||||
|
||||
encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
|
||||
|
||||
outputs = encoder.run(None, {
|
||||
[x, prompts] = encoder.run(None, {
|
||||
"text_seq": input_phones,
|
||||
"text_bert": input_bert,
|
||||
"ref_seq": ref_phones,
|
||||
@ -47,10 +84,7 @@ outputs = encoder.run(None, {
|
||||
"ssl_content": audio_prompt_hubert
|
||||
})
|
||||
|
||||
print(outputs[0].shape, outputs[1].shape)
|
||||
|
||||
x = outputs[0]
|
||||
prompts = outputs[1]
|
||||
print(x.shape, prompts.shape)
|
||||
|
||||
fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
|
||||
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
||||
|
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user