From b45cbc3561d85190e0061d0395d52e1dddb7d1ca Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Sat, 23 Aug 2025 13:03:02 -0400 Subject: [PATCH] feat: sampling params working now for export, todo:fold weights clean code --- GPT_SoVITS/AR/models/t2s_model_onnx.py | 2 +- GPT_SoVITS/onnx_export.py | 87 +++++++++++++------------- playground/freerun.py | 23 +++++-- 3 files changed, 62 insertions(+), 50 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py index 65c95d48..d0b50449 100644 --- a/GPT_SoVITS/AR/models/t2s_model_onnx.py +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -61,7 +61,7 @@ def logits_to_probs( ) logits = logits.masked_fill(indices_to_remove, -float("Inf")) - logits = logits / max(temperature, 1e-5) + logits = logits / torch.max(temperature, torch.tensor(1e-5, device=temperature.device, dtype=temperature.dtype)) # if top_k is not None: # To be captured by onnx v, _ = torch.topk(logits, top_k) diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index 87ea5e70..5858856e 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -94,14 +94,14 @@ class T2SInitStep(nn.Module): self.fsdc = t2s.first_stage_decoder self.vits = vits - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None): codes = self.vits.extract_latent(ssl_content) prompt_semantic = codes[0, 0] bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) bert = bert.unsqueeze(0) prompt = prompt_semantic.unsqueeze(0) - [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt) + [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) fake_logits = torch.randn((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export return y, k, v, y_emb, x_example, fake_logits, fake_samples @@ -111,8 +111,8 @@ class T2SStageStep(nn.Module): super().__init__() self.stage_decoder = stage_decoder - def forward(self, iy, ik, iv, iy_emb, ix_example): - [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example) + def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) fake_x_example = torch.randn((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export return y, k, v, y_emb, fake_x_example, logits, samples @@ -136,38 +136,27 @@ class T2SModel(nn.Module): self.stage_decoder = self.t2s_model.stage_decoder # self.t2s_model = torch.jit.script(self.t2s_model) - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None): early_stop_num = self.t2s_model.early_stop_num # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] - y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] - enco = self.stage_decoder(y, k, v, y_emb, x_example) + enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) y, k, v, y_emb, logits, samples = enco - print(logits.shape, samples.shape) if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: break - y[0, -1] = 0 return y[:, -idx:].unsqueeze(0) - def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): - # self.init_step = torch.jit.script(self.init_step) - if dynamo: - export_options = torch.onnx.ExportOptions(dynamic_shapes=True) - init_step_export_output = torch.onnx.dynamo_export( - self.init_step, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options - ) - init_step_export_output.save(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx") - return - + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None): torch.onnx.export( self.init_step, - (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + (ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature), f"onnx/{project_name}/{project_name}_t2s_init_step.onnx", - input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"], + input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature"], output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'], dynamic_axes={ "ref_text_phones": {1: "ref_length"}, @@ -178,14 +167,14 @@ class T2SModel(nn.Module): }, opset_version=16, ) - y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) stage_step = T2SStageStep(self.stage_decoder) torch.onnx.export( stage_step, - (y, k, v, y_emb, x_example), + (y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature), f"onnx/{project_name}/{project_name}_t2s_sdec.onnx", - input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], + input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature"], output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"], dynamic_axes={ "iy": {1: "iy_length"}, @@ -237,14 +226,14 @@ class GptSoVits(nn.Module): self.vits = vits self.t2s = t2s - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb): - pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb) return audio - def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name): - self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) - pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) torch.onnx.export( self.vits, (text_seq, pred_semantic, spectrum, sv_emb), @@ -304,6 +293,14 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi data_inputs_init = [input for input in init_step_model.graph.input] data_inputs_stage = [input for input in stage_step_model.graph.input] + # Get all names from both lists + names_list_init = {obj.name for obj in data_inputs_init} + names_list_stage = {obj.name for obj in data_inputs_stage} + # Find names that appear in both lists + repeated_input_names = names_list_init.intersection(names_list_stage) + # Filter out objects with repeated names + data_inputs_stage = [obj for obj in data_inputs_stage if obj.name not in repeated_input_names] + del then_graph.input[:] del else_graph.input[:] @@ -413,13 +410,17 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"): ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() text_bert = torch.randn((text_seq.shape[1], 1024)).float() ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5 + top_k = torch.LongTensor([15]) + top_p = torch.FloatTensor([1.0]) + repetition_penalty = torch.FloatTensor([1.0]) + temperature = torch.FloatTensor([1.0]) os.makedirs(f"onnx/{project_name}", exist_ok=True) [ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k) - gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float()) + gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) # exit() - gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name) + gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx", input_names=["audio32k"], @@ -444,12 +445,12 @@ if __name__ == "__main__": # version = "v1" # export(vits_path, gpt_path, exp_path, version) - gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" - vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" - exp_path = "v2_export" - version = "v2" - export(vits_path, gpt_path, exp_path, version) - combineInitStepAndStageStep('onnx/v2_export/v2_export_t2s_init_step.onnx', 'onnx/v2_export/v2_export_t2s_sdec.onnx', 'onnx/v2_export/v2_export_t2s_combined.onnx') + # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" + # exp_path = "v2_export" + # version = "v2" + # export(vits_path, gpt_path, exp_path, version) + # combineInitStepAndStageStep('onnx/v2_export/v2_export_t2s_init_step.onnx', 'onnx/v2_export/v2_export_t2s_sdec.onnx', 'onnx/v2_export/v2_export_t2s_combined.onnx') # gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" @@ -457,11 +458,11 @@ if __name__ == "__main__": # version = "v2Pro" # export(vits_path, gpt_path, exp_path, version) - # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" - # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" - # exp_path = "v2proplus_export" - # version = "v2ProPlus" - # export(vits_path, gpt_path, exp_path, version) - # combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx') + gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" + exp_path = "v2proplus_export" + version = "v2ProPlus" + export(vits_path, gpt_path, exp_path, version) + combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx') diff --git a/playground/freerun.py b/playground/freerun.py index 18736c93..d290ccb6 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -7,7 +7,7 @@ import torch from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx -MODEL_PATH = "onnx/v2_export/v2" +MODEL_PATH = "onnx/v2proplus_export/v2proplus" def audio_postprocess( audios, @@ -73,10 +73,13 @@ def preprocess_text(text:str): [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav") -np.save("playground/ref/audio_prompt_hubert.npy", audio_prompt_hubert.astype(np.float16)) - # audio_prompt_hubert_saved = np.load("playground/ref/audio_prompt_hubert.npy").astype(np.float32) +top_k = np.array([15], dtype=np.int64) +top_p = np.array([1.0], dtype=np.float32) +repetition_penalty = np.array([1.0], dtype=np.float32) +temperature = np.array([1.0], dtype=np.float32) + t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx") # t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx") @@ -91,7 +94,11 @@ t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx") "ik":np.empty((24, 0, 1, 512), dtype=np.float32), "iv":np.empty((24, 0, 1, 512), dtype=np.float32), "iy_emb":np.empty((1, 0, 512), dtype=np.float32), - "ix_example":np.empty((1, 0), dtype=np.float32) + "ix_example":np.empty((1, 0), dtype=np.float32), + "top_k": top_k, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "temperature": temperature }) # t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") @@ -109,7 +116,11 @@ for idx in tqdm(range(1, 1500)): "ik": k, "iv": v, "iy_emb": y_emb, - "ix_example": x_example + "ix_example": x_example, + "top_k": top_k, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "temperature": temperature }) if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token break @@ -124,7 +135,7 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx") "input_text_phones": input_phones, "pred_semantic": pred_semantic, "spectrum": spectrum.astype(np.float32), - # "sv_emb": sv_emb.astype(np.float32) + "sv_emb": sv_emb.astype(np.float32) }) audio_postprocess([audio])