feat:update kv cache to [len, head, dim] to allow linear size increasement

This commit is contained in:
zpeng11 2025-08-26 17:01:41 -04:00
parent fa84e262ae
commit 3e63595f0e
4 changed files with 16 additions and 18 deletions

View File

@ -273,8 +273,8 @@ class Text2SemanticDecoder(nn.Module):
x_seq_len = x.shape[1]
y_seq_len = prompts.shape[1]
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
empty_tensor = torch.empty((1,0,512)).to(torch.float)

View File

@ -51,20 +51,18 @@ def multi_head_attention_forward_patched(
# 使用动态形状推断来统一处理kv cache首步和后续步骤形状差异
# # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
# # cache_k, cache_v : [1, N, 1, 512] size increasement is prepared outside
first_infer_mask = cache["first_infer"]
cache_k = cache["k"][cache["stage"]]
cache_v = cache["v"][cache["stage"]]
# # cache_k, cache_v : [N, 1, 512] for one head, N size increasement is prepared outside
# cache["k"][:, cache["stage"]:cache["stage"]+1, :]
# cache["v"][:, cache["stage"]:cache["stage"]+1, :]
# Magic to get an index of either -1 or -N according to if first_infer_mask is set
minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len'])
multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len'])
index_offset = torch.min(minus_one, multipled)
cache_k[index_offset :, :, :] = k
cache_v[index_offset :, :, :] = v
cache["k"][cache["stage"]] = cache_k
cache["v"][cache["stage"]] = cache_v
k = cache_k
v = cache_v
# 首次时 index 为 -N后续index 为 -1
cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k
cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v
k = cache["k"][:, cache["stage"]:cache["stage"]+1, :]
v = cache["v"][:, cache["stage"]:cache["stage"]+1, :]
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]

View File

@ -157,7 +157,7 @@ def main():
parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
help="Pretrained BERT model name")
parser.add_argument("--output_dir", type=str, default="playground/bert",
parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large",
help="Output directory path")
parser.add_argument("--max_seq_length", type=int, default=512,
help="Maximum sequence length")

View File

@ -124,8 +124,8 @@ class T2SInitStage(nn.Module):
x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
return x, prompt, init_k, init_v, x_seq_len, y_seq_len
@ -210,8 +210,8 @@ class T2SModel(nn.Module):
dynamic_axes={
"ix": {1: "ix_length"},
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"ik": {0: "ik_length"},
"iv": {0: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
},
verbose=False,