feat: combined fsdc and encoder, todo: extract audio pipeline

zpeng11 2025-08-20 02:24:59 -04:00
parent 71cbe28e68
commit da5aa78224
2 changed files with 30 additions and 36 deletions


@@ -94,6 +94,7 @@ class T2SEncoder(nn.Module):
     def __init__(self, t2s, vits):
         super().__init__()
         self.encoder = t2s.onnx_encoder
+        self.fsdc = t2s.first_stage_decoder
         self.vits = vits

     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
@@ -103,7 +104,8 @@ class T2SEncoder(nn.Module):
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
         prompt = prompt_semantic.unsqueeze(0)
-        return self.encoder(all_phoneme_ids, bert), prompt
+        return self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)


 class T2SModel(nn.Module):
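This is the core of the commit: the first-stage decoder (fsdc) now runs inside the encoder's forward pass, so a single exported graph emits the stage decoder's initial state (y, k, v, y_emb, x_example) instead of the intermediate (x, prompts) pair that previously had to cross the ONNX boundary. A minimal sketch of the fusion pattern, with stand-in modules in place of the real t2s components:

    import torch.nn as nn

    class FusedEncoder(nn.Module):
        """Sketch: chain the text encoder and the first-stage decoder so one
        graph yields the autoregressive loop's initial state directly."""
        def __init__(self, encoder: nn.Module, first_stage_decoder: nn.Module):
            super().__init__()
            self.encoder = encoder
            self.fsdc = first_stage_decoder

        def forward(self, all_phoneme_ids, bert, prompt):
            x = self.encoder(all_phoneme_ids, bert)
            return self.fsdc(x, prompt)  # -> (y, k, v, y_emb, x_example)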
@@ -130,20 +132,13 @@ class T2SModel(nn.Module):
         early_stop_num = self.t2s_model.early_stop_num

         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-        prefix_len = prompts.shape[1]
-
-        # [1,N,512] [1,N]
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

         stop = False
         for idx in range(1, 1500):
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
-            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
-                stop = True
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
                 stop = True
             if stop:
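Since prompts is no longer returned, prefix_len is gone and the early_stop_num cap is dropped with it; generation now stops only on EOS or the hard 1500-step limit. If the cap is still wanted, one option (hypothetical, not part of this commit) is to have the fused encoder also return the prompt length:

    # Hypothetical: fused encoder extended to return prefix_len as a 6th output.
    y, k, v, y_emb, x_example, prefix_len = self.onnx_encoder(
        ref_seq, text_seq, ref_bert, text_bert, ssl_content
    )
    stop = False
    for idx in range(1, 1500):
        y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
        if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
            stop = True  # cap new tokens relative to the prompt length
        if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
            stop = True
        if stop:
            break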
@ -167,7 +162,7 @@ class T2SModel(nn.Module):
(ref_seq, text_seq, ref_bert, text_bert, ssl_content), (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx", f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"], output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={ dynamic_axes={
"ref_seq": {1: "ref_length"}, "ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"}, "text_seq": {1: "text_length"},
@@ -177,22 +172,22 @@ class T2SModel(nn.Module):
             },
             opset_version=16,
         )
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

-        torch.onnx.export(
-            self.first_stage_decoder,
-            (x, prompts),
-            f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
-            input_names=["x", "prompts"],
-            output_names=["y", "k", "v", "y_emb", "x_example"],
-            dynamic_axes={
-                "x": {1: "x_length"},
-                "prompts": {1: "prompts_length"},
-            },
-            verbose=False,
-            opset_version=16,
-        )
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        # torch.onnx.export(
+        #     self.first_stage_decoder,
+        #     (x, prompts),
+        #     f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
+        #     input_names=["x", "prompts"],
+        #     output_names=["y", "k", "v", "y_emb", "x_example"],
+        #     dynamic_axes={
+        #         "x": {1: "x_length"},
+        #         "prompts": {1: "prompts_length"},
+        #     },
+        #     verbose=False,
+        #     opset_version=16,
+        # )
+        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

         torch.onnx.export(
             self.stage_decoder,


@ -7,7 +7,7 @@ import torch
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
MODEL_PATH = "playground/v2proplus_export/v2proplus" MODEL_PATH = "playground/v2pro_export/v2pro"
def audio_postprocess( def audio_postprocess(
audios, audios,
@@ -49,7 +49,7 @@ def load_and_preprocess_audio(audio_path, max_length=160000):
     waveform = waveform[:, :max_length]

     # make a zero tensor that has length 3200*0.3
-    zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
+    zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)

     # concate zero_tensor with waveform
     waveform = torch.cat([waveform, zero_tensor], dim=1)
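The trailing silence shrinks from 9600 to 4800 samples; note the comment's arithmetic (3200*0.3) matches neither value. If the tensor at this point is 16 kHz audio (what HuBERT consumes), 4800 samples is 0.3 s. A hypothetical helper that derives the pad from the sample rate instead of a hard-coded count:

    import torch

    def pad_with_silence(waveform: torch.Tensor, sample_rate: int, seconds: float = 0.3) -> torch.Tensor:
        # Append `seconds` of zeros so the pad tracks the actual sample rate.
        pad = torch.zeros((waveform.shape[0], int(sample_rate * seconds)), dtype=waveform.dtype)
        return torch.cat([waveform, pad], dim=1)

    # At 16 kHz this appends 4800 zero samples, matching the new constant above.
    padded = pad_with_silence(torch.zeros(1, 160000), sample_rate=16000)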
@@ -75,7 +75,7 @@ def preprocess_text(text:str):
 # input_phones_saved = np.load("playground/ref/input_phones.npy")
 # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-[input_phones, input_bert] = preprocess_text("震撼视角,感受成都世运会,闭幕式烟花")
+[input_phones, input_bert] = preprocess_text("震撼视角,感受开幕式,闭幕式烟花")

 # ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -88,7 +88,7 @@ audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")

 encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
-[x, prompts] = encoder.run(None, {
+[y, k, v, y_emb, x_example] = encoder.run(None, {
     "text_seq": input_phones,
     "text_bert": input_bert,
     "ref_seq": ref_phones,
@@ -96,18 +96,16 @@ encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
     "ssl_content": audio_prompt_hubert
 })

-fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
+# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
 sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")

 # for i in tqdm(range(10000)):
-[y, k, v, y_emb, x_example] = fsdec.run(None, {
-    "x": x,
-    "prompts": prompts
-})
-prefix_len = prompts.shape[1]
+# [y, k, v, y_emb, x_example] = fsdec.run(None, {
+#     "x": x,
+#     "prompts": prompts
+# })

 stop = False
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
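With the fsdec session removed, the encoder's outputs now seed the stage-decoder loop directly. A sketch of the feedback loop that follows, assuming the sdec graph uses the i-prefixed input names from the upstream export script and an EOS id of 1024; verify both against the actual model (e.g. sdec.get_inputs()):

    import numpy as np

    EOS = 1024  # assumption: the t2s model's EOS token id
    for idx in tqdm(range(1, 1500)):
        # Feed each step's outputs back in as inputs; x_example never changes.
        y, k, v, y_emb, logits, samples = sdec.run(None, {
            "iy": y, "ik": k, "iv": v, "iy_emb": y_emb, "ix_example": x_example,
        })  # input names are an assumption from the upstream export
        if np.argmax(logits, axis=-1)[0] == EOS or samples[0, 0] == EOS:
            break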
@@ -129,6 +127,7 @@ waveform, sample_rate = torchaudio.load("playground/ref/audio.wav")
 if sample_rate != 32000:
     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
     waveform = resampler(waveform)
+print(f"Waveform shape: {waveform.shape}")
 ref_audio = waveform.numpy().astype(np.float32)

 vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
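From here the script presumably trims the generated semantic tokens and feeds them, together with the 32 kHz reference audio, to the vits session (note the variable is spelled vtis in the script). A hedged sketch of that final call, assuming the usual GPT-SoVITS vits export I/O names; the v2pro variants may additionally expect a speaker-embedding input:

    # Assumed input names; check vtis.get_inputs() for the actual ones, and
    # whether the v2pro export also wants an "sv_emb" speaker vector.
    pred_semantic = np.expand_dims(y[:, -idx:], axis=0)  # tokens emitted after the prompt
    [audio] = vtis.run(None, {
        "text_seq": input_phones,
        "pred_semantic": pred_semantic,
        "ref_audio": ref_audio,
    })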