# Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torchaudio
from tqdm import tqdm

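# Prefix for the exported ONNX graphs; the four sessions below expect
# <prefix>_export_t2s_encoder.onnx, <prefix>_export_t2s_fsdec.onnx,
# <prefix>_export_t2s_sdec.onnx, and <prefix>_export_vits.onnx.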
MODEL_PATH = "playground/v2proplus_export/v2proplus"


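# Normalize each fragment, pad it with a short trailing silence, concatenate
# the fragments, and write the result to playground/output.wav at 32 kHz.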
def audio_postprocess(
    audios,
    fragment_interval: float = 0.3,
):
    zero_wav = np.zeros((int(32000 * fragment_interval),)).astype(np.float32)
    for i, audio in enumerate(audios):
        max_audio = np.abs(audio).max()  # simple guard against 16-bit clipping
        if max_audio > 1:
            audio /= max_audio
        audio = np.concatenate([audio, zero_wav], axis=0)
        audios[i] = audio

    audio = np.concatenate(audios, axis=0)

    # audio = (audio * 32768).astype(np.int16)

    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
    torchaudio.save("playground/output.wav", audio_tensor, 32000)
    return audio


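# Precomputed model inputs: phoneme sequences and BERT features for the target
# and reference text, plus SSL (HuBERT) features of the reference audio.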
input_phones = np.load("playground/input_phones.npy")
input_bert = np.load("playground/input_bert.npy").T.astype(np.float32)
ref_phones = np.load("playground/ref_phones.npy")
ref_bert = np.load("playground/ref_bert.npy").T.astype(np.float32)
audio_prompt_hubert = np.load("playground/audio_prompt_hubert.npy").astype(np.float32)

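# Text-to-semantic encoder: builds the decoder input `x` from the phoneme and
# BERT features, and derives the semantic prompt tokens `prompts` from the SSL content.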
encoder = ort.InferenceSession(MODEL_PATH + "_export_t2s_encoder.onnx")
outputs = encoder.run(None, {
    "text_seq": input_phones,
    "text_bert": input_bert,
    "ref_seq": ref_phones,
    "ref_bert": ref_bert,
    "ssl_content": audio_prompt_hubert,
})

print(outputs[0].shape, outputs[1].shape)

x = outputs[0]
prompts = outputs[1]

fsdec = ort.InferenceSession(MODEL_PATH + "_export_t2s_fsdec.onnx")
sdec = ort.InferenceSession(MODEL_PATH + "_export_t2s_sdec.onnx")

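# First-stage decoder: a single pass over `x` and the prompt that yields the
# initial token sequence `y`, its embedding, and the attention KV cache.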
# for i in tqdm(range(10000)):
[y, k, v, y_emb, x_example] = fsdec.run(None, {
    "x": x,
    "prompts": prompts,
})

early_stop_num = -1
prefix_len = prompts.shape[1]

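# Step decoder: autoregressively sample one semantic token per iteration,
# feeding the KV cache back in. Token id 1024 acts as the stop (EOS) marker.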
stop = False
for idx in tqdm(range(1, 1500)):
    # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
    [y, k, v, y_emb, logits, samples] = sdec.run(None, {
        "iy": y,
        "ik": k,
        "iv": v,
        "iy_emb": y_emb,
        "ix_example": x_example,
    })
    if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
        stop = True
    if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:
        stop = True
    if stop:
        break

# Zero out the trailing stop token before decoding with VITS.
y[0, -1] = 0

# Keep only the newly generated tokens: the last `idx` positions of `y`.
pred_semantic = np.expand_dims(y[:, -idx:], axis=0)

# Read the reference audio and resample it to the 32 kHz model rate.
waveform, sample_rate = torchaudio.load("playground/ref/audio.wav")
if sample_rate != 32000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
    waveform = resampler(waveform)
ref_audio = waveform.numpy().astype(np.float32)

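# VITS decoder: synthesize the waveform from the predicted semantic tokens,
# conditioned on the target phonemes and the reference audio.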
vits = ort.InferenceSession(MODEL_PATH + "_export_vits.onnx")

[audio] = vits.run(None, {
    "text_seq": input_phones,
    "pred_semantic": pred_semantic,
    "ref_audio": ref_audio,
})
print(audio.shape, audio.dtype, audio.min(), audio.max())

audio_postprocess([audio])