diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index fd6578f2..c8a3d876 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -273,8 +273,8 @@ class Text2SemanticDecoder(nn.Module):
         x_seq_len = x.shape[1]
         y_seq_len = prompts.shape[1]
 
-        init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-        init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
+        init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
+        init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
 
         empty_tensor = torch.empty((1,0,512)).to(torch.float)
 
diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
index 83c92b0c..f8461f7d 100644
--- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
+++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
@@ -51,20 +51,18 @@ def multi_head_attention_forward_patched(
 
     # Use dynamic shape inference to handle the kv-cache shape difference between the first step and later steps uniformly
     #
     # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
-    # # cache_k, cache_v : [1, N, 1, 512], the size increase is prepared outside
-    first_infer_mask = cache["first_infer"]
-    cache_k = cache["k"][cache["stage"]]
-    cache_v = cache["v"][cache["stage"]]
+    # # cache_k, cache_v : [N, 1, 512] for one head, the increase of N is prepared outside
+    # cache["k"][:, cache["stage"]:cache["stage"]+1, :]
+    # cache["v"][:, cache["stage"]:cache["stage"]+1, :]
     # Magic to get an index of either -1 or -N according to whether first_infer_mask is set
     minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
-    multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len'])
+    multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len'])
     index_offset = torch.min(minus_one, multipled)
-    cache_k[index_offset :, :, :] = k
-    cache_v[index_offset :, :, :] = v
-    cache["k"][cache["stage"]] = cache_k
-    cache["v"][cache["stage"]] = cache_v
-    k = cache_k
-    v = cache_v
+    # On the first step the index is -N; on later steps it is -1
+    cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k
+    cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v
+    k = cache["k"][:, cache["stage"]:cache["stage"]+1, :]
+    v = cache["v"][:, cache["stage"]:cache["stage"]+1, :]
     cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
 
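Reviewer note (illustration, not part of the patch): the torch.min(minus_one, multipled) trick above yields a slice start of -N on the first inference step (first_infer == 1) and -1 on every later step (first_infer == 0), so one ONNX graph can both fill the whole prompt prefix and append a single token without data-dependent control flow. A minimal standalone sketch, with made-up lengths and a hypothetical helper name compute_index_offset:

import torch

def compute_index_offset(first_infer, x_seq_len, y_seq_len):
    # Mirrors the patched logic: min(-1, -first_infer * N) is -N when
    # first_infer == 1 (write the whole prefix) and -1 when it is 0
    # (write only the newest cache position).
    minus_one = torch.tensor([-1], dtype=torch.int64)
    multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
    return torch.min(minus_one, multipled)

x_len, y_len = torch.tensor([3]), torch.tensor([2])
print(compute_index_offset(torch.tensor([1]), x_len, y_len))  # tensor([-5]): first step
print(compute_index_offset(torch.tensor([0]), x_len, y_len))  # tensor([-1]): later steps
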
diff --git a/GPT_SoVITS/export_roberta_onnx.py b/GPT_SoVITS/export_roberta_onnx.py
index 860d1af2..cea4c4ee 100644
--- a/GPT_SoVITS/export_roberta_onnx.py
+++ b/GPT_SoVITS/export_roberta_onnx.py
@@ -157,7 +157,7 @@ def main():
     parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
     parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
                         help="Pretrained BERT model name")
-    parser.add_argument("--output_dir", type=str, default="playground/bert",
+    parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large",
                         help="Output directory path")
     parser.add_argument("--max_seq_length", type=int, default=512,
                         help="Maximum sequence length")
diff --git a/GPT_SoVITS/onnx_export_v1v2.py b/GPT_SoVITS/onnx_export_v1v2.py
index 172eba1d..fad6e09f 100644
--- a/GPT_SoVITS/onnx_export_v1v2.py
+++ b/GPT_SoVITS/onnx_export_v1v2.py
@@ -124,8 +124,8 @@ class T2SInitStage(nn.Module):
         x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
         y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]
 
-        init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-        init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
+        init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
+        init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
 
         return x, prompt, init_k, init_v, x_seq_len, y_seq_len
 
@@ -210,8 +210,8 @@ class T2SModel(nn.Module):
             dynamic_axes={
                 "ix": {1: "ix_length"},
                 "iy": {1: "iy_length"},
-                "ik": {1: "ik_length"},
-                "iv": {1: "iv_length"},
+                "ik": {0: "ik_length"},
+                "iv": {0: "iv_length"},
                 "iy_emb": {1: "iy_emb_length"},
             },
             verbose=False,
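
Reviewer note (illustration, not part of the patch): across these files the kv cache moves from a per-layer [num_layers, N, 1, 512] tensor indexed as cache["k"][stage] to a single [N, num_layers, 512] tensor sliced as cache["k"][:, stage:stage+1, :]. The growing sequence axis N is now dim 0, which is what the "ik"/"iv" dynamic_axes change from {1: ...} to {0: ...} reflects. A minimal sketch of the new layout, with made-up sizes:

import torch

num_layers, N = 4, 5                         # illustrative, not the exporter's real config
cache_k = torch.zeros((N, num_layers, 512))  # new layout: sequence axis first

stage = 2                                    # current decoder layer
k_step = torch.randn(1, 1, 512)              # one new position for this layer
cache_k[-1:, stage:stage + 1, :] = k_step    # later steps write only the newest slot
k = cache_k[:, stage:stage + 1, :]           # [N, 1, 512] slice used as attention keys
print(k.shape)                               # torch.Size([5, 1, 512])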