diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py index 203cbfc4..fd6578f2 100644 --- a/GPT_SoVITS/AR/models/t2s_model_onnx.py +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -145,10 +145,12 @@ class T2SStageDecoder(nn.Module): "y_emb": y_emb, "first_infer": first_infer, "stage": 0, + "x_seq_len": x_seq_len, + "y_seq_len": y_seq_len, } # 运行时判断对最后一个y还是整个y做embedding,以正确应对首次和后续 - multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1] + multipled = minus_one * first_infer * y_seq_len index_offset = torch.min(minus_one, multipled) y_to_emb = y[:, index_offset:] # 对y输入进行embedding @@ -165,7 +167,7 @@ class T2SStageDecoder(nn.Module): xy_pos = torch.concat([x, y_pos], dim=1) # 运行时判断对最后一个xy_pos还是整个xy_pos做self attention - multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1] + multipled = minus_one * first_infer * (x_seq_len + y_seq_len) # xy_pos = 1 or x_seq_len + y_seq_len index_offset = torch.min(minus_one, multipled) xy_pos = xy_pos[:, index_offset:] @@ -189,7 +191,7 @@ class T2SStageDecoder(nn.Module): xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) # 运行时判断attension mask使用最后一个还是整个 - multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_attn_mask)[0] + multipled = minus_one * first_infer * (x_seq_len + y_seq_len) index_offset = torch.min(minus_one, multipled) xy_attn_mask = xy_attn_mask[index_offset:, :] @@ -268,20 +270,26 @@ class Text2SemanticDecoder(nn.Module): x = self.onnx_encoder(x, bert_feature) - init_k = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float) - init_v = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float) + x_seq_len = x.shape[1] + y_seq_len = prompts.shape[1] + + init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) + init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) + + empty_tensor = torch.empty((1,0,512)).to(torch.float) y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v, - torch.empty((1,0,512)).to(torch.float), top_k=top_k, + empty_tensor, top_k=top_k, first_infer=torch.LongTensor([1]), - x_seq_len=x.shape[1], y_seq_len=prompts.shape[1]) + x_seq_len=x_seq_len, y_seq_len=y_seq_len) stop = False for idx in tqdm(range(1, 1500)): k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) - enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, top_k=top_k, - first_infer=torch.LongTensor([0]), x_seq_len=x.shape[1], y_seq_len=y.shape[1]) + y_seq_len = y.shape[1] + enco = self.stage_decoder(empty_tensor, y, k, v, y_emb, top_k=top_k, + first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len) y, k, v, y_emb, logits, samples = enco if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: stop = True diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py index 04ad5d87..83c92b0c 100644 --- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -57,7 +57,7 @@ def multi_head_attention_forward_patched( cache_v = cache["v"][cache["stage"]] # Magic to get an index of either -1 or -N according to if first_infer_mask is set minus_one = torch.tensor([-1]).to(k.device).to(torch.int64) - multipled = minus_one * first_infer_mask * torch.onnx.operators.shape_as_tensor(query)[0] + multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len']) index_offset = torch.min(minus_one, multipled) cache_k[index_offset :, :, :] = k cache_v[index_offset :, :, :] = v