From c85ee3d5212dd9255aa85e4407341540e6384385 Mon Sep 17 00:00:00 2001
From: zpeng11
Date: Mon, 25 Aug 2025 17:57:04 -0400
Subject: [PATCH] feat: unify the first step and following steps into a single
 decoder

---
 GPT_SoVITS/AR/models/t2s_model_onnx.py | 164 ++++++-------------------
 1 file changed, 38 insertions(+), 126 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index 58c20263..203cbfc4 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -105,110 +105,6 @@ class OnnxEncoder(nn.Module):
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         return self.ar_text_position(x)
 
-
-class T2SFirstStageDecoder(nn.Module):
-    def __init__(
-        self,
-        ar_audio_embedding,
-        ar_audio_position,
-        h,
-        ar_predict_layer,
-        loss_fct,
-        ar_accuracy_metric,
-        early_stop_num,
-        num_layers,
-    ):
-        super().__init__()
-        self.ar_audio_embedding = ar_audio_embedding
-        self.ar_audio_position = ar_audio_position
-        self.h = h
-        self.ar_predict_layer = ar_predict_layer
-        self.loss_fct = loss_fct
-        self.ar_accuracy_metric = ar_accuracy_metric
-        self.early_stop_num = early_stop_num
-        self.num_layers = num_layers
-
-    def forward(self, x, prompt, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
-        if top_k is None:
-            top_k = torch.LongTensor([15]).to(device=x.device)
-        if top_p is None:
-            top_p = torch.FloatTensor([1.0]).to(device=x.device)
-        if repetition_penalty is None:
-            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
-        if temperature is None:
-            temperature = torch.FloatTensor([1.0]).to(device=x.device)
-        minus_one = torch.tensor([-1]).to(x.device).to(torch.int64)
-
-        y = prompt
-        x_example = x[:, :, 0] * 0.0
-        # N, 1, 512
-        cache = {
-            "all_stage": self.num_layers,
-            "k": None,
-            "v": None,
-            "y_emb": y_emb,
-            "first_infer": first_infer,
-            "stage": 0,
-        }
-
-        # Decide at runtime whether to embed only the last y token or the whole y, so the first and following steps are both handled correctly
-        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
-        index_offset = torch.min(minus_one, multipled)
-        y_to_emb = y[:, index_offset:]
-
-        print("y_emb shape:", y_emb.shape)
-
-        y_emb = torch.cat(
-            [
-                cache["y_emb"],
-                self.ar_audio_embedding(y_to_emb),
-            ],
-            1,
-        )
-        cache["y_emb"] = y_emb
-
-        y_pos = self.ar_audio_position(y_emb)
-
-        xy_pos = torch.concat([x, y_pos], dim=1)
-
-        # Decide at runtime whether self-attention runs over only the last xy_pos entry or all of it
-        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
-        index_offset = torch.min(minus_one, multipled)
-        xy_pos = xy_pos[:, index_offset:]
-
-        # Build the attention mask for the concatenated xy sequence
-        x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
-        y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
-        y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
-            torch.ones(
-                (y_seq_len, 1),
-                dtype=torch.int64,
-            ),
-            dim=0,
-        )
-        y_attn_mask = y_attn_mask > 0
-
-        x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
-        y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
-
-        x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
-        y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
-        print("first iter xy_attn_mask shape:", xy_attn_mask.shape)
-        cache["k"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-        cache["v"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-
-        print("first iter cache k shape:", cache["k"].shape, 'cache v shape:', cache["v"].shape)
-
-        xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
-        logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
-
-        y = torch.concat([y, samples], dim=1)
-
-        return y, cache["k"], cache["v"], cache["y_emb"], x_example
-
-
 class T2SStageDecoder(nn.Module):
     def __init__(
         self,
@@ -231,7 +127,7 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, x, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
+    def forward(self, x, y, k, v, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
         if top_k is None:
             top_k = torch.LongTensor([15]).to(device=y.device)
         if top_p is None:
@@ -244,8 +140,8 @@ class T2SStageDecoder(nn.Module):
 
         cache = {
             "all_stage": self.num_layers,
-            "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
-            "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
+            "k": k,
+            "v": v,
             "y_emb": y_emb,
             "first_infer": first_infer,
             "stage": 0,
@@ -273,10 +169,29 @@ class T2SStageDecoder(nn.Module):
         index_offset = torch.min(minus_one, multipled)
         xy_pos = xy_pos[:, index_offset:]
 
-        x_example_len = torch.onnx.operators.shape_as_tensor(x_example)[1]
-        y_example_len = torch.onnx.operators.shape_as_tensor(y_pos)[1]
-        xy_attn_mask = torch.zeros((1, x_example_len + y_example_len), dtype=torch.bool)
-        print('xy_attn_mask shape:', xy_attn_mask.shape)
+        # Build the attention mask for the concatenated xy sequence
+        x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
+        y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
+        y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
+            torch.ones(
+                (y_seq_len, 1),
+                dtype=torch.int64,
+            ),
+            dim=0,
+        )
+        y_attn_mask = y_attn_mask > 0
+
+        x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
+        y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
+
+        x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
+        y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+
+        # Decide at runtime whether to use only the last row of the attention mask or the whole mask
+        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_attn_mask)[0]
+        index_offset = torch.min(minus_one, multipled)
+        xy_attn_mask = xy_attn_mask[index_offset:, :]
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -332,16 +247,6 @@ class Text2SemanticDecoder(nn.Module):
 
     def init_onnx(self):
         self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
-        self.first_stage_decoder = T2SFirstStageDecoder(
-            self.ar_audio_embedding,
-            self.ar_audio_position,
-            self.h,
-            self.ar_predict_layer,
-            self.loss_fct,
-            self.ar_accuracy_metric,
-            self.early_stop_num,
-            self.num_layers,
-        )
         self.stage_decoder = T2SStageDecoder(
             self.ar_audio_embedding,
             self.ar_audio_position,
@@ -362,14 +267,21 @@ class Text2SemanticDecoder(nn.Module):
         prefix_len = prompts.shape[1]
 
         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, torch.empty((1,0,512)).to(torch.float), top_k=top_k,
-                                                             first_infer=torch.LongTensor([1]),
-                                                             x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])
+
+        init_k = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float)
+        init_v = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float)
+
+        y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
+                                                             torch.empty((1,0,512)).to(torch.float), top_k=top_k,
+                                                             first_infer=torch.LongTensor([1]),
+                                                             x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])
 
         stop = False
         for idx in tqdm(range(1, 1500)):
-            enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, x_example, top_k=top_k,
-                                      first_infer=torch.LongTensor([0]))
+            k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
+            v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
+            enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, top_k=top_k,
+                                       first_infer=torch.LongTensor([0]), x_seq_len=x.shape[1], y_seq_len=y.shape[1])
            y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
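
Reviewer note, not part of the patch: the unification works because the decoder never branches in Python on "first step vs. following step". It derives a negative slice offset from the `first_infer` flag, so one traced ONNX graph serves both cases. A minimal sketch of that trick (the helper name `take_span` is hypothetical; the (batch, seq, dim) layout follows the patch):

    import torch

    def take_span(t: torch.Tensor, first_infer: torch.Tensor) -> torch.Tensor:
        # Branchless select: the whole sequence when first_infer == 1,
        # only the last position when first_infer == 0.
        minus_one = torch.tensor([-1], dtype=torch.int64)
        seq_len = torch.onnx.operators.shape_as_tensor(t)[1]
        # first_infer == 1: min(-1, -seq_len) = -seq_len -> t[:, -seq_len:] is everything
        # first_infer == 0: min(-1, 0)        = -1       -> t[:, -1:] is the last step
        multipled = minus_one * first_infer * seq_len
        index_offset = torch.min(minus_one, multipled)
        return t[:, index_offset:]

    t = torch.arange(10).reshape(1, 10)
    assert take_span(t, torch.LongTensor([1])).shape[1] == 10  # first step: full sequence
    assert take_span(t, torch.LongTensor([0])).shape[1] == 1   # following steps: last token only

The same pattern appears three times in the unified decoder: for the y embedding, for xy_pos, and (new in this change) for slicing the attention mask down to its last row on incremental steps.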
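A second note on the cache handling, which this patch moves from the decoder into the caller: the first call receives zero-filled k/v caches sized for the full prompt (x plus y), and every later iteration pads one time slot onto the caches before invoking the decoder again, replacing the internal `F.pad` that `T2SStageDecoder` used to apply. A small sketch of that step (the 24-layer count is illustrative; 512 and the (layers, time, batch, width) layout come from the patch):

    import torch
    import torch.nn.functional as F

    num_layers, batch, width = 24, 1, 512
    x_len, y_len = 7, 5

    # First step: caller-allocated zero caches covering the whole prompt.
    k = torch.zeros((num_layers, x_len + y_len, batch, width), dtype=torch.float)

    # Each following step appends one empty time slot for the new token.
    # F.pad pairs apply from the last dim backwards, so (0,0, 0,0, 0,1)
    # leaves dims -1 and -2 alone and pads dim -3 (time) by one at the end.
    k = F.pad(k, (0, 0, 0, 0, 0, 1))
    assert k.shape == (num_layers, x_len + y_len + 1, batch, width)

Moving the pad out of the decoder keeps the exported graph identical for both step types; only the cache tensors the caller feeds in change shape between iterations.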