From da5aa78224c253138f61f68db19c80d2c8e822d8 Mon Sep 17 00:00:00 2001
From: zpeng11
Date: Wed, 20 Aug 2025 02:24:59 -0400
Subject: [PATCH] feat: combine fsdc and encoder; todo: extract audio pipeline

---
 GPT_SoVITS/onnx_export.py | 45 +++++++++++++++++-------------------------
 playground/freerun.py     | 21 ++++++++++-----------
 2 files changed, 30 insertions(+), 36 deletions(-)

diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index 48d64a7e..8b2fad11 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -94,6 +94,7 @@ class T2SEncoder(nn.Module):
     def __init__(self, t2s, vits):
         super().__init__()
         self.encoder = t2s.onnx_encoder
+        self.fsdc = t2s.first_stage_decoder
         self.vits = vits
 
     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
@@ -103,7 +104,8 @@ class T2SEncoder(nn.Module):
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
         prompt = prompt_semantic.unsqueeze(0)
-        return self.encoder(all_phoneme_ids, bert), prompt
+        return self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
+
 
 
 class T2SModel(nn.Module):
@@ -130,20 +132,13 @@ class T2SModel(nn.Module):
         early_stop_num = self.t2s_model.early_stop_num
 
         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-
-        prefix_len = prompts.shape[1]
-
-        # [1,N,512] [1,N]
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
         stop = False
         for idx in range(1, 1500):
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
-            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
-                stop = True
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
                 stop = True
             if stop:
@@ -167,7 +162,7 @@ class T2SModel(nn.Module):
             (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
             f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
            input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
-            output_names=["x", "prompts"],
+            output_names=["y", "k", "v", "y_emb", "x_example"],
             dynamic_axes={
                 "ref_seq": {1: "ref_length"},
                 "text_seq": {1: "text_length"},
@@ -177,22 +172,22 @@ class T2SModel(nn.Module):
             },
             opset_version=16,
         )
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
-        torch.onnx.export(
-            self.first_stage_decoder,
-            (x, prompts),
-            f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
-            input_names=["x", "prompts"],
-            output_names=["y", "k", "v", "y_emb", "x_example"],
-            dynamic_axes={
-                "x": {1: "x_length"},
-                "prompts": {1: "prompts_length"},
-            },
-            verbose=False,
-            opset_version=16,
-        )
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        # torch.onnx.export(
+        #     self.first_stage_decoder,
+        #     (x, prompts),
+        #     f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
+        #     input_names=["x", "prompts"],
+        #     output_names=["y", "k", "v", "y_emb", "x_example"],
+        #     dynamic_axes={
+        #         "x": {1: "x_length"},
+        #         "prompts": {1: "prompts_length"},
+        #     },
+        #     verbose=False,
+        #     opset_version=16,
+        # )
+        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
 
         torch.onnx.export(
             self.stage_decoder,
diff --git a/playground/freerun.py b/playground/freerun.py
index 4377e44c..225395c6 100644
--- a/playground/freerun.py
+++ b/playground/freerun.py
@@ -7,7 +7,7 @@ import torch
 
 from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
 
-MODEL_PATH = "playground/v2proplus_export/v2proplus"
+MODEL_PATH = "playground/v2pro_export/v2pro"
 
 def audio_postprocess(
     audios,
@@ -49,7 +49,7 @@ def load_and_preprocess_audio(audio_path, max_length=160000):
         waveform = waveform[:, :max_length]
 
     # make a zero tensor that has length 3200*0.3
-    zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
+    zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
 
     # concate zero_tensor with waveform
     waveform = torch.cat([waveform, zero_tensor], dim=1)
@@ -75,7 +75,7 @@ def preprocess_text(text:str):
 
 # input_phones_saved = np.load("playground/ref/input_phones.npy")
 # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-[input_phones, input_bert] = preprocess_text("震撼视角,感受成都世运会,闭幕式烟花")
+[input_phones, input_bert] = preprocess_text("震撼视角,感受开幕式,闭幕式烟花")
 
 
 # ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -88,7 +88,7 @@
 audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
 
 encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
-[x, prompts] = encoder.run(None, {
+[y, k, v, y_emb, x_example] = encoder.run(None, {
     "text_seq": input_phones,
     "text_bert": input_bert,
     "ref_seq": ref_phones,
@@ -96,18 +96,16 @@ encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
     "ssl_content": audio_prompt_hubert
 })
 
-fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
+# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
 sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
 
 
 # for i in tqdm(range(10000)):
-[y, k, v, y_emb, x_example] = fsdec.run(None, {
-    "x": x,
-    "prompts": prompts
-})
+# [y, k, v, y_emb, x_example] = fsdec.run(None, {
+#     "x": x,
+#     "prompts": prompts
+# })
 
-prefix_len = prompts.shape[1]
-
 stop = False
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
@@ -129,6 +127,7 @@ waveform, sample_rate = torchaudio.load("playground/ref/audio.wav")
 if sample_rate != 32000:
     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
     waveform = resampler(waveform)
+print(f"Waveform shape: {waveform.shape}")
 ref_audio = waveform.numpy().astype(np.float32)
 
 vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
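
Reviewer note (not part of the patch): the commit fuses T2SEncoder and the first-stage decoder so that a single t2s_encoder.onnx graph emits the decoder state (y, k, v, y_emb, x_example) directly. One visible side effect: prompts, and therefore prefix_len = prompts.shape[1], never leaves the fused graph, which is why the early_stop_num guard is dropped from T2SModel.forward and stopping now relies on EOS alone. Below is a minimal, self-contained sketch of the fusion pattern. The Toy* classes are stand-ins, not the real GPT_SoVITS modules, and the toy takes prompts as an explicit input, whereas the real T2SEncoder derives it from ssl_content.

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        # stand-in for t2s.onnx_encoder: (phoneme ids, BERT features) -> [1, N, 512]
        def forward(self, phoneme_ids, bert):
            return torch.randn(1, phoneme_ids.shape[1], 512)

    class ToyFirstStageDecoder(nn.Module):
        # stand-in for t2s.first_stage_decoder: (x, prompts) -> (y, k, v, y_emb, x_example)
        def forward(self, x, prompts):
            y = torch.cat([prompts, torch.zeros(1, 1, dtype=prompts.dtype)], dim=1)
            k = torch.randn(24, y.shape[1], 1, 512)
            v = torch.randn(24, y.shape[1], 1, 512)
            y_emb = torch.randn(1, y.shape[1], 512)
            x_example = torch.zeros(1, x.shape[1])
            return y, k, v, y_emb, x_example

    class FusedEncoder(nn.Module):
        # mirrors the patched T2SEncoder: the encoder output feeds the first-stage
        # decoder inside one module, so one torch.onnx.export captures both stages
        def __init__(self, encoder, fsdc):
            super().__init__()
            self.encoder = encoder
            self.fsdc = fsdc

        def forward(self, phoneme_ids, bert, prompts):
            return self.fsdc(self.encoder(phoneme_ids, bert), prompts)

    fused = FusedEncoder(ToyEncoder(), ToyFirstStageDecoder())
    y, k, v, y_emb, x_example = fused(
        torch.zeros(1, 8, dtype=torch.long),  # phoneme ids
        torch.randn(1, 8, 1024),              # BERT features
        torch.zeros(1, 4, dtype=torch.long),  # prompt semantic tokens
    )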
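Reviewer note (not part of the patch): on the runtime side, the fused export reduces freerun.py to one encoder call plus the sdec loop. A hedged sketch of that flow, assuming the file layout used above, an sdec graph whose input order matches the encoder's output order, and an EOS token id taken from the t2s config (its value is not shown in this patch):

    import numpy as np
    import onnxruntime as ort

    MODEL_PATH = "playground/v2pro_export/v2pro"  # layout as in freerun.py above

    encoder = ort.InferenceSession(MODEL_PATH + "_export_t2s_encoder.onnx")
    sdec = ort.InferenceSession(MODEL_PATH + "_export_t2s_sdec.onnx")

    def decode_semantic(input_phones, input_bert, ref_phones, ref_bert, hubert,
                        eos_id, max_steps=1500):
        # one call now produces the first-stage decoder state; no fsdec session,
        # and no prompts tensor from which to derive prefix_len
        y, k, v, y_emb, x_example = encoder.run(None, {
            "text_seq": input_phones,
            "text_bert": input_bert,
            "ref_seq": ref_phones,
            "ref_bert": ref_bert,
            "ssl_content": hubert,
        })
        names = [i.name for i in sdec.get_inputs()]  # avoid hard-coding exported names
        for _ in range(1, max_steps):
            y, k, v, y_emb, logits, samples = sdec.run(
                None, dict(zip(names, (y, k, v, y_emb, x_example))))
            # same stop rule as the loops above: EOS from argmax or from the sample
            if np.argmax(logits, axis=-1)[0] == eos_id or samples[0, 0] == eos_id:
                break
        return y  # semantic tokens to feed the vits graph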