diff --git a/.gitignore b/.gitignore
index 473c09c3..cefe803f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -193,4 +193,5 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
-onnx/
\ No newline at end of file
+onnx/
+*.onnx
\ No newline at end of file
diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index 661c2edd..97b56b93 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 import torchaudio
 from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 from feature_extractor import cnhubert
@@ -39,6 +40,27 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
     return spec
 
+def resample_audio(audio: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
+    """
+    Resample audio from orig_sr to target_sr using linear interpolation.
+    audio: (batch, channels, samples) or (channels, samples) or (samples,)
+    """
+    if audio.dim() == 1:
+        audio = audio.unsqueeze(0).unsqueeze(0)
+    elif audio.dim() == 2:
+        audio = audio.unsqueeze(0)
+    # audio shape: (batch, channels, samples)
+    batch, channels, samples = audio.shape
+    new_samples = int(samples * target_sr / orig_sr)
+    audio = audio.view(batch * channels, 1, samples)
+    resampled = F.interpolate(audio, size=new_samples, mode='linear', align_corners=False)
+    resampled = resampled.view(batch, channels, new_samples)
+    if resampled.shape[0] == 1 and resampled.shape[1] == 1:
+        resampled = resampled.squeeze(0).squeeze(0)
+    elif resampled.shape[0] == 1:
+        resampled = resampled.squeeze(0)
+    return resampled
+
 
 class DictToAttrRecursive(dict):
     def __init__(self, input_dict):
@@ -225,7 +247,7 @@ class VitsModel(nn.Module):
             center=False,
         )
         if self.sv_model is not None:
-            sv_emb=self.sv_model.compute_embedding3_onnx(ref_audio)
+            sv_emb=self.sv_model.compute_embedding3_onnx(resample_audio(ref_audio, 32000, 16000))
             return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0]
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
 
@@ -292,7 +314,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
                     "y",
                     "e4",
                 ],
-                version=voice_model_version,
+                version='v2',
             )
         ]
     )
@@ -325,7 +347,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
                     "y",
                     "e4",
                 ],
-                version=voice_model_version,
+                version='v2',
             )
         ]
     )
@@ -380,6 +402,11 @@ if __name__ == "__main__":
     # version = "v2"
     # export(vits_path, gpt_path, exp_path, version)
 
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
+    # exp_path = "v2pro_export"
+    # version = "v2Pro"
+
     gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
     vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
     exp_path = "v2proplus_export"
diff --git a/playground/audio_prompt_hubert.npy b/playground/audio_prompt_hubert.npy
new file mode 100644
index 00000000..414da85c
Binary files /dev/null and b/playground/audio_prompt_hubert.npy differ
diff --git a/playground/freerun.py b/playground/freerun.py
new file mode 100644
index 00000000..6adcd8f4
--- /dev/null
+++ b/playground/freerun.py
@@ -0,0 +1,102 @@
+import onnxruntime as ort
+import numpy as np
+import onnx
+from tqdm import tqdm
+import torchaudio
+import torch
+
+MODEL_PATH = "playground/v2proplus_export/v2proplus"
+
+def audio_postprocess(
+    audios,
+    fragment_interval: float = 0.3,
+):
+    zero_wav = np.zeros((int(32000 * fragment_interval),)).astype(np.float32)
+    for i, audio in enumerate(audios):
+        max_audio = np.abs(audio).max()  # simple guard against 16-bit clipping
+        if max_audio > 1:
+            audio /= max_audio
+        audio = np.concatenate([audio, zero_wav], axis=0)
+        audios[i] = audio
+
+    audio = np.concatenate(audios, axis=0)
+
+    # audio = (audio * 32768).astype(np.int16)
+
+    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
+
+    torchaudio.save('playground/output.wav', audio_tensor, 32000)
+
+    return audio
+
+
+input_phones = np.load("playground/input_phones.npy")
+input_bert = np.load("playground/input_bert.npy").T.astype(np.float32)
+ref_phones = np.load("playground/ref_phones.npy")
+ref_bert = np.load("playground/ref_bert.npy").T.astype(np.float32)
+audio_prompt_hubert = np.load("playground/audio_prompt_hubert.npy").astype(np.float32)
+
+
+encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
+
+outputs = encoder.run(None, {
+    "text_seq": input_phones,
+    "text_bert": input_bert,
+    "ref_seq": ref_phones,
+    "ref_bert": ref_bert,
+    "ssl_content": audio_prompt_hubert
+})
+
+print(outputs[0].shape, outputs[1].shape)
+
+x = outputs[0]
+prompts = outputs[1]
+
+fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
+sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
+
+# for i in tqdm(range(10000)):
+[y, k, v, y_emb, x_example] = fsdec.run(None, {
+    "x": x,
+    "prompts": prompts
+})
+
+early_stop_num = -1
+prefix_len = prompts.shape[1]
+
+stop = False
+for idx in tqdm(range(1, 1500)):
+    # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
+    [y, k, v, y_emb, logits, samples] = sdec.run(None, {
+        "iy": y,
+        "ik": k,
+        "iv": v,
+        "iy_emb": y_emb,
+        "ix_example": x_example
+    })
+    if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+        stop = True
+    if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:
+        stop = True
+    if stop:
+        break
+y[0, -1] = 0
+
+
+pred_semantic = np.expand_dims(y[:, -idx:], axis=0)
+# Read and resample reference audio
+waveform, sample_rate = torchaudio.load("playground/ref/audio.wav")
+if sample_rate != 32000:
+    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
+    waveform = resampler(waveform)
+ref_audio = waveform.numpy().astype(np.float32)
+vits = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
+
+[audio] = vits.run(None, {
+    "text_seq": input_phones,
+    "pred_semantic": pred_semantic,
+    "ref_audio": ref_audio
+})
+print(audio.shape, audio.dtype, audio.min(), audio.max())
+
+audio_postprocess([audio])
diff --git a/playground/input_bert.npy b/playground/input_bert.npy
new file mode 100644
index 00000000..05025a50
Binary files /dev/null and b/playground/input_bert.npy differ
diff --git a/playground/input_phones.npy b/playground/input_phones.npy
new file mode 100644
index 00000000..38fbc8d7
Binary files /dev/null and b/playground/input_phones.npy differ
diff --git a/playground/output.wav b/playground/output.wav
new file mode 100644
index 00000000..7ff11483
Binary files /dev/null and b/playground/output.wav differ
diff --git a/playground/ref/audio.wav b/playground/ref/audio.wav
new file mode 100644
index 00000000..78320d71
Binary files /dev/null and b/playground/ref/audio.wav differ
diff --git a/playground/ref_bert.npy b/playground/ref_bert.npy
new file mode 100644
index 00000000..88a8b647
Binary files /dev/null and b/playground/ref_bert.npy differ
diff --git a/playground/ref_phones.npy b/playground/ref_phones.npy
new file mode 100644
index 00000000..b8abf26e
Binary files /dev/null and b/playground/ref_phones.npy differ
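
Note on the new resample_audio helper: the v2Pro branch feeds compute_embedding3_onnx a 16 kHz waveform, so the 32 kHz reference audio is downsampled with F.interpolate before the call. Below is a minimal, self-contained sketch (not part of the patch; the 220 Hz test tone and the comparison itself are illustrative assumptions) that checks this linear-interpolation downsampling against torchaudio's bandlimited resampler, the same transform freerun.py uses to bring the reference audio to 32 kHz.

    # Sketch: compare linear-interpolation downsampling (as in resample_audio)
    # with torchaudio's Resample on a 1 s, 220 Hz tone at 32 kHz -> 16 kHz.
    import math

    import torch
    import torch.nn.functional as F
    import torchaudio

    sr_in, sr_out = 32000, 16000
    t = torch.arange(sr_in, dtype=torch.float32) / sr_in
    wave = torch.sin(2 * math.pi * 220.0 * t)

    # Linear interpolation on a (batch, channels, samples) tensor, as resample_audio does.
    lin = F.interpolate(
        wave.view(1, 1, -1),
        size=int(wave.numel() * sr_out / sr_in),
        mode="linear",
        align_corners=False,
    ).view(-1)

    # Reference: torchaudio's bandlimited resampler.
    ref = torchaudio.transforms.Resample(orig_freq=sr_in, new_freq=sr_out)(wave)

    print(lin.shape, ref.shape)      # both torch.Size([16000])
    print((lin - ref).abs().max())   # rough agreement expected for low-frequency content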