diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index 5050d78d..58c20263 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -128,7 +128,7 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

-    def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
+    def forward(self, x, prompt, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
         if top_k is None:
             top_k = torch.LongTensor([15]).to(device=x.device)
         if top_p is None:
@@ -137,6 +137,7 @@ class T2SFirstStageDecoder(nn.Module):
             repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
         if temperature is None:
             temperature = torch.FloatTensor([1.0]).to(device=x.device)
+        minus_one = torch.tensor([-1]).to(x.device).to(torch.int64)

         y = prompt
         x_example = x[:, :, 0] * 0.0
@@ -145,45 +146,59 @@ class T2SFirstStageDecoder(nn.Module):
             "all_stage": self.num_layers,
             "k": None,
             "v": None,
-            "y_emb": None,
+            "y_emb": y_emb,
             "first_infer": first_infer,
             "stage": 0,
         }
-        y_emb = self.ar_audio_embedding(y)
+        # Decide at run time whether to embed only the last y token or the whole y, so the first and subsequent passes are both handled correctly
+        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
+        index_offset = torch.min(minus_one, multipled)
+        y_to_emb = y[:, index_offset:]
+        print("y_emb shape:", y_emb.shape)
+
+        y_emb = torch.cat(
+            [
+                cache["y_emb"],
+                self.ar_audio_embedding(y_to_emb),
+            ],
+            1,
+        )
         cache["y_emb"] = y_emb
+
         y_pos = self.ar_audio_position(y_emb)

         xy_pos = torch.concat([x, y_pos], dim=1)

-        y_example = y_pos[:, :, 0] * 0.0
-        x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
-        y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
+        # Decide at run time whether self-attention covers only the last xy_pos step or the whole xy_pos
+        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
+        index_offset = torch.min(minus_one, multipled)
+        xy_pos = xy_pos[:, index_offset:]
+
+        # Build the attention mask for the concatenated xy sequence
+        x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
+        y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
         y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
-            torch.ones_like(
-                y_example.transpose(0, 1),
+            torch.ones(
+                (y_seq_len, 1),
                 dtype=torch.int64,
             ),
             dim=0,
         )
         y_attn_mask = y_attn_mask > 0

-        x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
-        y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
-        x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
+        x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
+        y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
+
+        x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
         y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
-        cache["k"] = (
-            torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
-            .unsqueeze(1)
-            .repeat(self.num_layers, 1, 1, 1)
-        )
-        cache["v"] = (
-            torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
-            .unsqueeze(1)
-            .repeat(self.num_layers, 1, 1, 1)
-        )
+        print("first iter xy_attn_mask shape:", xy_attn_mask.shape)
+        cache["k"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
+        cache["v"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
+
+        print("first iter cache k shape:", cache["k"].shape, "cache v shape:", cache["v"].shape)

         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -216,7 +231,7 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

-    def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
+    def forward(self, x, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
         if top_k is None:
             top_k = torch.LongTensor([15]).to(device=y.device)
         if top_p is None:
@@ -225,7 +240,8 @@ class T2SStageDecoder(nn.Module):
             repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
         if temperature is None:
             temperature = torch.FloatTensor([1.0]).to(device=y.device)
-
+        minus_one = torch.tensor([-1]).to(y.device).to(torch.int64)
+
         cache = {
             "all_stage": self.num_layers,
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
@@ -235,22 +251,32 @@ class T2SStageDecoder(nn.Module):
             "stage": 0,
         }

+        # Decide at run time whether to embed only the last y token or the whole y, so the first and subsequent passes are both handled correctly
+        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
+        index_offset = torch.min(minus_one, multipled)
+        y_to_emb = y[:, index_offset:]
+        # Embed the y input
         y_emb = torch.cat(
             [
                 cache["y_emb"],
-                self.ar_audio_embedding(y[:, -1:]),
+                self.ar_audio_embedding(y_to_emb),
             ],
             1,
         )
         cache["y_emb"] = y_emb
         y_pos = self.ar_audio_position(y_emb)
+        # Concatenate with the x input in preparation for attention
+        xy_pos = torch.concat([x, y_pos], dim=1)

-        xy_pos = y_pos[:, -1:]
+        # Decide at run time whether self-attention covers only the last xy_pos step or the whole xy_pos
+        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
+        index_offset = torch.min(minus_one, multipled)
+        xy_pos = xy_pos[:, index_offset:]

-        y_example = y_pos[:, :, 0] * 0.0
-
-        xy_attn_mask = torch.cat([x_example, y_example], dim=1)
-        xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
+        x_example_len = torch.onnx.operators.shape_as_tensor(x_example)[1]
+        y_example_len = torch.onnx.operators.shape_as_tensor(y_pos)[1]
+        xy_attn_mask = torch.zeros((1, x_example_len + y_example_len), dtype=torch.bool)
+        print("xy_attn_mask shape:", xy_attn_mask.shape)

         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -336,11 +362,14 @@ class Text2SemanticDecoder(nn.Module):
         prefix_len = prompts.shape[1]

         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k, first_infer=torch.LongTensor([1]))
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, torch.empty((1, 0, 512)).to(torch.float), top_k=top_k,
+                                                             first_infer=torch.LongTensor([1]),
+                                                             x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])

         stop = False
         for idx in tqdm(range(1, 1500)):
-            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, first_infer=torch.LongTensor([0]))
+            enco = self.stage_decoder(torch.empty((1, 0, 512)).to(torch.float), y, k, v, y_emb, x_example, top_k=top_k,
+                                      first_infer=torch.LongTensor([0]))
             y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = true
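
Note on the first_infer slicing trick this patch uses in both decoders: a Python `if` would be frozen at trace time, so the patch instead computes a slice offset from the `first_infer` flag, letting one exported ONNX graph embed the whole prompt on the first pass and only the newest token afterwards. Below is a minimal standalone sketch of that idea in eager PyTorch; the helper name `slice_for_infer` and the demo values are illustrative, not part of the patch.

import torch

def slice_for_infer(y, first_infer):
    # first_infer is LongTensor([1]) on the first call and LongTensor([0])
    # afterwards, matching the driver loop in Text2SemanticDecoder above.
    minus_one = torch.tensor([-1], dtype=torch.int64, device=y.device)
    seq_len = torch.onnx.operators.shape_as_tensor(y)[1]
    # first_infer == 1 -> offset = min(-1, -seq_len) = -seq_len (keep everything)
    # first_infer == 0 -> offset = min(-1, 0)        = -1       (keep last step only)
    multipled = minus_one * first_infer * seq_len
    index_offset = torch.min(minus_one, multipled)
    # A tensor-valued slice bound is traced as a dynamic Slice op, so the
    # first-vs-subsequent behavior survives in the exported graph.
    return y[:, index_offset:]

y = torch.arange(10).reshape(1, 10)
assert slice_for_infer(y, torch.LongTensor([1])).shape[1] == 10  # first pass: whole sequence
assert slice_for_infer(y, torch.LongTensor([0])).shape[1] == 1   # later passes: last token only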