diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index 70a5aee7..41408059 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -133,14 +133,11 @@ class T2SModel(nn.Module):
         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
         y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
-        stop = False
-        for idx in tqdm(range(1, 1500)):
+        for idx in tqdm(range(1, 20)):  # fake loop for ONNX tracing only; do not take the 20-step bound as a reference for real inference
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
-                stop = True
-            if stop:
                 break
         y[0, -1] = 0
@@ -216,9 +213,7 @@ class VitsModel(nn.Module):
         else:
             self.hps["model"]["version"] = version
 
-        self.sv_model = None
-        if version == "v2ProPlus" or version == "v2Pro":
-            self.sv_model = SV("cpu", False)
+        self.is_v2p = version.lower() in ['v2pro', 'v2proplus']
 
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
@@ -231,11 +226,11 @@ class VitsModel(nn.Module):
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
 
     #filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
-    def forward(self, text_seq, pred_semantic, ref_audio, spectrum):
-        if self.sv_model is not None:
-            sv_emb=self.sv_model.compute_embedding3_onnx(resample_audio(ref_audio, 32000, 16000))
+    def forward(self, text_seq, pred_semantic, spectrum, sv_emb):
+        if self.is_v2p:
             return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
-        return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
+        else:
+            return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
 
 
 class GptSoVits(nn.Module):
@@ -244,19 +239,19 @@ class GptSoVits(nn.Module):
         self.vits = vits
         self.t2s = t2s
 
-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb):
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-        audio = self.vits(text_seq, pred_semantic, ref_audio, spectrum)
+        audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
         return audio
 
-    def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content, project_name):
+    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name):
         self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
         torch.onnx.export(
             self.vits,
-            (text_seq, pred_semantic, ref_audio, spectrum),
+            (text_seq, pred_semantic, spectrum, sv_emb),
             f"onnx/{project_name}/{project_name}_vits.onnx",
-            input_names=["text_seq", "pred_semantic", "ref_audio", "spectrum"],
+            input_names=["text_seq", "pred_semantic", "spectrum", "sv_emb"],
             output_names=["audio"],
             dynamic_axes={
                 "text_seq": {1: "text_length"},
@@ -269,7 +264,7 @@ class GptSoVits(nn.Module):
         )
 
 
-class HuBertSSLModel(nn.Module):
+class AudioPreprocess(nn.Module):
     def __init__(self):
         super().__init__()
         self.config = HubertConfig.from_pretrained(cnhubert_base_path)
@@ -285,6 +280,8 @@
         )
         self.model.eval()
 
+        self.sv_model = SV("cpu", False)
+
     def forward(self, ref_audio_32k):
         spectrum = spectrogram_torch(
             ref_audio_32k,
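Note: with this refactor the speaker-verification (SV) model is constructed once inside the preprocessing module, so the VITS ONNX graph no longer takes raw reference audio; it consumes a precomputed sv_emb tensor instead. A minimal sketch of the intended wiring for the v2Pro path (names taken from this patch; text_seq, pred_semantic, and vits_path stand in for values produced elsewhere in the export flow):

    # Sketch only, not part of the patch: how the modules are wired after this change.
    pre = AudioPreprocess()                                    # owns SV("cpu", False)
    ssl_content, spectrum, sv_emb = pre(ref_audio_32k)         # SV embedding computed once, here
    vits = VitsModel(vits_path, version="v2Pro")
    audio = vits(text_seq, pred_semantic, spectrum, sv_emb)    # sv_emb is now a plain graph input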
@@ -294,22 +291,24 @@ class HuBertSSLModel(nn.Module):
             2048,
             center=False,
         )
+        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000)
+        sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio_16k)
 
-        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0)
         zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
+        ref_audio_16k = ref_audio_16k.unsqueeze(0)
 
         # concate zero_tensor with waveform
         ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
         ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
-        return ssl_content, spectrum
+        return ssl_content, spectrum, sv_emb
 
 
 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
-    ssl = HuBertSSLModel()
+    preprocessor = AudioPreprocess()
     ref_seq = torch.LongTensor(
         [
             cleaned_text_to_sequence(
@@ -376,19 +375,19 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     except:
         pass
 
-    torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx",
+    torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
                       input_names=["audio32k"],
-                      output_names=["hubert_ssl_output", "spectrum"],
+                      output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
                       dynamic_axes={
-                          "audio32k": {0: "batch_size", 1: "sequence_length"},
-                          "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"},
-                          "spectrum": {0: "batch_size", 2: "spectrum_length"}
+                          "audio32k": {1: "sequence_length"},
+                          "hubert_ssl_output": {2: "hubert_length"},
+                          "spectrum": {2: "spectrum_length"}
                       })
 
-    [ssl_content, spectrum] = ssl(ref_audio32k)
-    gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float())
+    [ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
+    gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float())
     # exit()
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float(), project_name)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name)
 
     if voice_model_version == "v1":
         symbols = symbols_v1
@@ -401,11 +400,11 @@ if __name__ == "__main__":
     except:
         pass
 
-    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
-    # exp_path = "v2_export"
-    # version = "v2"
-    # export(vits_path, gpt_path, exp_path, version)
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    exp_path = "v2_export"
+    version = "v2"
+    export(vits_path, gpt_path, exp_path, version)
 
     gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
     vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
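Note: the v2 preset above is re-enabled rather than left commented out, so __main__ now exports v2, v2Pro, and v2ProPlus in sequence. A hypothetical consolidation (not in this patch) would drive the three runs from a preset table instead of repeating the assignment block:

    # Hypothetical refactor: one loop over the three presets exported below.
    GPT_PATH = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
    PRESETS = [
        ("GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "v2_export", "v2"),
        ("GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", "v2pro_export", "v2Pro"),
        ("GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth", "v2proplus_export", "v2ProPlus"),
    ]
    for vits_path, exp_path, version in PRESETS:
        export(vits_path, GPT_PATH, exp_path, version)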
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" + exp_path = "v2proplus_export" + version = "v2ProPlus" + export(vits_path, gpt_path, exp_path, version) diff --git a/playground/freerun.py b/playground/freerun.py index a9b8ed14..ab21410f 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -7,7 +7,7 @@ import torch from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx -MODEL_PATH = "onnx/v2pro_export/v2pro" +MODEL_PATH = "onnx/v2proplus_export/v2proplus" def audio_postprocess( audios, @@ -31,42 +31,7 @@ def audio_postprocess( return audio -# def load_and_preprocess_audio(audio_path, max_length=160000): -# """Load and preprocess audio file to 16k""" -# waveform, sample_rate = torchaudio.load(audio_path) - -# # Resample to 16kHz if needed -# if sample_rate != 16000: -# resampler = torchaudio.transforms.Resample(sample_rate, 16000) -# waveform = resampler(waveform) - -# # Take first channel -# if waveform.shape[0] > 1: -# waveform = waveform[0:1] - -# # Limit length for testing (10 seconds at 16kHz) -# if waveform.shape[1] > max_length: -# waveform = waveform[:, :max_length] - -# # make a zero tensor that has length 3200*0.3 -# zero_tensor = torch.zeros((1, 4800), dtype=torch.float32) - -# # concate zero_tensor with waveform -# waveform = torch.cat([waveform, zero_tensor], dim=1) - -# return waveform - -# def get_audio_hubert(audio_path): -# """Get HuBERT features for the audio file""" -# waveform = load_and_preprocess_audio(audio_path) -# ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx") -# ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()} -# hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32) -# # transpose axis 1 and 2 with numpy -# hubert_feature = hubert_feature.transpose(0, 2, 1) -# return hubert_feature - -def load_and_preprocess_audio(audio_path): +def load_audio(audio_path): """Load and preprocess audio file to 32k""" waveform, sample_rate = torchaudio.load(audio_path) @@ -81,15 +46,13 @@ def load_and_preprocess_audio(audio_path): return waveform -def get_audio_hubert(audio_path): +def audio_preprocess(audio_path): """Get HuBERT features for the audio file""" - waveform = load_and_preprocess_audio(audio_path) - ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx") + waveform = load_audio(audio_path) + ort_session = ort.InferenceSession(MODEL_PATH + "_export_audio_preprocess.onnx") ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()} - [hubert_feature, spectrum] = ort_session.run(None, ort_inputs) - # transpose axis 1 and 2 with numpy - # hubert_feature = hubert_feature.transpose(0, 2, 1) - return hubert_feature, spectrum + [hubert_feature, spectrum, sv_emb] = ort_session.run(None, ort_inputs) + return hubert_feature, spectrum, sv_emb def preprocess_text(text:str): preprocessor = TextPreprocessorOnnx("playground/bert") @@ -108,7 +71,7 @@ def preprocess_text(text:str): [ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织") -[audio_prompt_hubert, spectrum] = get_audio_hubert("playground/ref/audio.wav") +[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav") encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx") @@ -131,7 +94,6 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") # }) -stop = False for idx in tqdm(range(1, 1500)): # [1, N] [N_layer, N, 1, 512] 
diff --git a/playground/freerun.py b/playground/freerun.py
index a9b8ed14..ab21410f 100644
--- a/playground/freerun.py
+++ b/playground/freerun.py
@@ -7,7 +7,7 @@ import torch
 
 from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
 
-MODEL_PATH = "onnx/v2pro_export/v2pro"
+MODEL_PATH = "onnx/v2proplus_export/v2proplus"
 
 def audio_postprocess(
     audios,
@@ -31,42 +31,7 @@ def audio_postprocess(
     return audio
 
 
-# def load_and_preprocess_audio(audio_path, max_length=160000):
-#     """Load and preprocess audio file to 16k"""
-#     waveform, sample_rate = torchaudio.load(audio_path)
-
-#     # Resample to 16kHz if needed
-#     if sample_rate != 16000:
-#         resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-#         waveform = resampler(waveform)
-
-#     # Take first channel
-#     if waveform.shape[0] > 1:
-#         waveform = waveform[0:1]
-
-#     # Limit length for testing (10 seconds at 16kHz)
-#     if waveform.shape[1] > max_length:
-#         waveform = waveform[:, :max_length]
-
-#     # make a zero tensor that has length 3200*0.3
-#     zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
-
-#     # concate zero_tensor with waveform
-#     waveform = torch.cat([waveform, zero_tensor], dim=1)
-
-#     return waveform
-
-# def get_audio_hubert(audio_path):
-#     """Get HuBERT features for the audio file"""
-#     waveform = load_and_preprocess_audio(audio_path)
-#     ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
-#     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
-#     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
-#     # transpose axis 1 and 2 with numpy
-#     hubert_feature = hubert_feature.transpose(0, 2, 1)
-#     return hubert_feature
-
-def load_and_preprocess_audio(audio_path):
+def load_audio(audio_path):
     """Load and preprocess audio file to 32k"""
     waveform, sample_rate = torchaudio.load(audio_path)
 
@@ -81,15 +46,13 @@ def load_and_preprocess_audio(audio_path):
     return waveform
 
 
-def get_audio_hubert(audio_path):
+def audio_preprocess(audio_path):
     """Get HuBERT features for the audio file"""
-    waveform = load_and_preprocess_audio(audio_path)
-    ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx")
+    waveform = load_audio(audio_path)
+    ort_session = ort.InferenceSession(MODEL_PATH + "_export_audio_preprocess.onnx")
     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
-    [hubert_feature, spectrum] = ort_session.run(None, ort_inputs)
-    # transpose axis 1 and 2 with numpy
-    # hubert_feature = hubert_feature.transpose(0, 2, 1)
-    return hubert_feature, spectrum
+    [hubert_feature, spectrum, sv_emb] = ort_session.run(None, ort_inputs)
+    return hubert_feature, spectrum, sv_emb
 
 def preprocess_text(text:str):
     preprocessor = TextPreprocessorOnnx("playground/bert")
@@ -108,7 +71,7 @@ def preprocess_text(text:str):
 
 [ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
 
-[audio_prompt_hubert, spectrum] = get_audio_hubert("playground/ref/audio.wav")
+[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
 
 encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
 
@@ -131,7 +94,6 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
 #     })
 
-stop = False
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
     [y, k, v, y_emb, logits, samples] = sdec.run(None, {
@@ -160,8 +122,8 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
 [audio] = vtis.run(None, {
     "text_seq": input_phones,
     "pred_semantic": pred_semantic,
-    "ref_audio": ref_audio,
-    "spectrum": spectrum.astype(np.float32)
+    "spectrum": spectrum.astype(np.float32),
+    "sv_emb": sv_emb.astype(np.float32)
 })
 
 audio_postprocess([audio])
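Note: audio_postprocess returns the concatenated waveform but the script does not persist it. A hypothetical final step (not in this patch) to inspect the result, assuming the soundfile package is available and using the 32 kHz output rate noted in the onnx_export.py comment above:

    import soundfile as sf

    # Hypothetical: write the synthesized waveform to disk for listening.
    wav = audio_postprocess([audio])   # assumed to be a 1-D float array
    sf.write("playground/output.wav", wav, 32000)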