feat:update kv cache to [len, head, dim] to allow linear size increasement

This commit is contained in:
zpeng11 2025-08-26 17:01:41 -04:00
parent fa84e262ae
commit 3e63595f0e
4 changed files with 16 additions and 18 deletions

View File

@ -273,8 +273,8 @@ class Text2SemanticDecoder(nn.Module):
x_seq_len = x.shape[1] x_seq_len = x.shape[1]
y_seq_len = prompts.shape[1] y_seq_len = prompts.shape[1]
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
empty_tensor = torch.empty((1,0,512)).to(torch.float) empty_tensor = torch.empty((1,0,512)).to(torch.float)

View File

@ -51,20 +51,18 @@ def multi_head_attention_forward_patched(
# 使用动态形状推断来统一处理kv cache首步和后续步骤形状差异 # 使用动态形状推断来统一处理kv cache首步和后续步骤形状差异
# # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards # # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
# # cache_k, cache_v : [1, N, 1, 512] size increasement is prepared outside # # cache_k, cache_v : [N, 1, 512] for one head, N size increasement is prepared outside
first_infer_mask = cache["first_infer"] # cache["k"][:, cache["stage"]:cache["stage"]+1, :]
cache_k = cache["k"][cache["stage"]] # cache["v"][:, cache["stage"]:cache["stage"]+1, :]
cache_v = cache["v"][cache["stage"]]
# Magic to get an index of either -1 or -N according to if first_infer_mask is set # Magic to get an index of either -1 or -N according to if first_infer_mask is set
minus_one = torch.tensor([-1]).to(k.device).to(torch.int64) minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len']) multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len'])
index_offset = torch.min(minus_one, multipled) index_offset = torch.min(minus_one, multipled)
cache_k[index_offset :, :, :] = k # 首次时 index 为 -N后续index 为 -1
cache_v[index_offset :, :, :] = v cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k
cache["k"][cache["stage"]] = cache_k cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v
cache["v"][cache["stage"]] = cache_v k = cache["k"][:, cache["stage"]:cache["stage"]+1, :]
k = cache_k v = cache["v"][:, cache["stage"]:cache["stage"]+1, :]
v = cache_v
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]

View File

@ -157,7 +157,7 @@ def main():
parser = argparse.ArgumentParser(description="Export BERT model to ONNX") parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
help="Pretrained BERT model name") help="Pretrained BERT model name")
parser.add_argument("--output_dir", type=str, default="playground/bert", parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large",
help="Output directory path") help="Output directory path")
parser.add_argument("--max_seq_length", type=int, default=512, parser.add_argument("--max_seq_length", type=int, default=512,
help="Maximum sequence length") help="Maximum sequence length")

View File

@ -124,8 +124,8 @@ class T2SInitStage(nn.Module):
x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1] x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1] y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
return x, prompt, init_k, init_v, x_seq_len, y_seq_len return x, prompt, init_k, init_v, x_seq_len, y_seq_len
@ -210,8 +210,8 @@ class T2SModel(nn.Module):
dynamic_axes={ dynamic_axes={
"ix": {1: "ix_length"}, "ix": {1: "ix_length"},
"iy": {1: "iy_length"}, "iy": {1: "iy_length"},
"ik": {1: "ik_length"}, "ik": {0: "ik_length"},
"iv": {1: "iv_length"}, "iv": {0: "iv_length"},
"iy_emb": {1: "iy_emb_length"}, "iy_emb": {1: "iy_emb_length"},
}, },
verbose=False, verbose=False,