Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 00:30:15 +08:00)
feat: update kv cache to [len, head, dim] to allow linear size increase
parent fa84e262ae
commit 3e63595f0e
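The change moves the kv cache from a per-layer [num_layers, len, 1, 512] layout to a length-first [len, num_layers, 512] layout, so the cache can be extended along dim 0 as decoding proceeds instead of reallocating every layer's buffer separately. A minimal sketch of that idea, assuming layer/dim sizes of 24 and 512 (grow_cache is an illustrative helper, not from the repo):

import torch

num_layers, dim = 24, 512  # assumed sizes for illustration

def grow_cache(cache_k: torch.Tensor, steps: int = 1) -> torch.Tensor:
    # With the length-first layout, appending a decode step is one
    # concatenation along dim 0; all layers grow at once.
    new_rows = torch.zeros((steps, num_layers, dim), dtype=cache_k.dtype)
    return torch.cat([cache_k, new_rows], dim=0)

cache_k = torch.zeros((12, num_layers, dim))  # text + prompt length 12
cache_k = grow_cache(cache_k)                 # length 13 after one step
print(cache_k.shape)                          # torch.Size([13, 24, 512])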
@@ -273,8 +273,8 @@ class Text2SemanticDecoder(nn.Module):
         x_seq_len = x.shape[1]
         y_seq_len = prompts.shape[1]
 
-        init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-        init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
+        init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
+        init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
 
         empty_tensor = torch.empty((1,0,512)).to(torch.float)
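The init buffers are sized to the full text-plus-prompt length up front, and the attention patch below slices one stage out of dim 1. A quick shape check under assumed lengths (x_seq_len=7, y_seq_len=5 are made-up values):

import torch

num_layers, dim = 24, 512  # assumed sizes
x_seq_len, y_seq_len = 7, 5  # assumed lengths

init_k = torch.zeros(((x_seq_len + y_seq_len), num_layers, dim), dtype=torch.float)
# One stage's slice keeps a singleton dim, matching the incoming k of [N, 1, 512].
stage = 0
print(init_k[:, stage:stage + 1, :].shape)  # torch.Size([12, 1, 512])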
@@ -51,20 +51,18 @@ def multi_head_attention_forward_patched(
 
-    # Use dynamic shape inference to handle the kv-cache shape difference between the first step and later steps uniformly
     # # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
-    # # cache_k, cache_v : [1, N, 1, 512] size increase is prepared outside
-    first_infer_mask = cache["first_infer"]
-    cache_k = cache["k"][cache["stage"]]
-    cache_v = cache["v"][cache["stage"]]
+    # # cache_k, cache_v : [N, 1, 512] for one head, N size increase is prepared outside
+    # cache["k"][:, cache["stage"]:cache["stage"]+1, :]
+    # cache["v"][:, cache["stage"]:cache["stage"]+1, :]
+    # Magic to get an index of either -1 or -N according to whether first_infer_mask is set
     minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
-    multipled = minus_one * first_infer_mask * (cache['x_seq_len'] + cache['y_seq_len'])
+    multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len'])
     index_offset = torch.min(minus_one, multipled)
-    cache_k[index_offset:, :, :] = k
-    cache_v[index_offset:, :, :] = v
-    cache["k"][cache["stage"]] = cache_k
-    cache["v"][cache["stage"]] = cache_v
-    k = cache_k
-    v = cache_v
+    # index is -N on the first step, -1 on later steps
+    cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k
+    cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v
+    k = cache["k"][:, cache["stage"]:cache["stage"]+1, :]
+    v = cache["v"][:, cache["stage"]:cache["stage"]+1, :]
 
     cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
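The torch.min(minus_one, multipled) line is the branch-free core of the patch: it yields -N when first_infer is 1 (so the whole prompt is written into the tail of the cache) and -1 on later steps (writing only the newest token), without a Python if that would be frozen during ONNX tracing. A standalone check with assumed lengths:

import torch

x_seq_len, y_seq_len = torch.tensor(7), torch.tensor(5)  # assumed lengths
minus_one = torch.tensor([-1]).to(torch.int64)

for first_infer in (torch.tensor(1), torch.tensor(0)):
    multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
    index_offset = torch.min(minus_one, multipled)
    print(int(index_offset))  # -12 on the first step, then -1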
@@ -157,7 +157,7 @@ def main():
     parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
     parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
                         help="Pretrained BERT model name")
-    parser.add_argument("--output_dir", type=str, default="playground/bert",
+    parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large",
                         help="Output directory path")
     parser.add_argument("--max_seq_length", type=int, default=512,
                         help="Maximum sequence length")
@@ -124,8 +124,8 @@ class T2SInitStage(nn.Module):
         x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
         y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]
 
-        init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-        init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
+        init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
+        init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
 
         return x, prompt, init_k, init_v, x_seq_len, y_seq_len
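Here the lengths come from torch.onnx.operators.shape_as_tensor rather than .shape, which keeps them symbolic in the exported graph instead of freezing the trace-time values. A self-contained sketch with made-up inputs:

import torch

x = torch.randn(1, 7, 512)       # assumed [batch, len, dim] text features
prompt = torch.randn(1, 5, 512)  # assumed prompt features

# shape_as_tensor returns the shape as an int64 tensor, so the length
# arithmetic below stays a graph node during ONNX export.
x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]

init_k = torch.zeros(((x_seq_len + y_seq_len), 24, 512), dtype=torch.float)
print(init_k.shape)              # torch.Size([12, 24, 512])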
@@ -210,8 +210,8 @@ class T2SModel(nn.Module):
             dynamic_axes={
                 "ix": {1: "ix_length"},
                 "iy": {1: "iy_length"},
-                "ik": {1: "ik_length"},
-                "iv": {1: "iv_length"},
+                "ik": {0: "ik_length"},
+                "iv": {0: "iv_length"},
                 "iy_emb": {1: "iy_emb_length"},
             },
             verbose=False,
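With the cache length now in dim 0 of ik/iv, the dynamic axis in the export call moves from 1 to 0. A hedged sketch of how such a dynamic_axes mapping is declared (the module and file name are placeholders, not the repo's export code):

import torch

class CachePassthrough(torch.nn.Module):
    def forward(self, ik: torch.Tensor) -> torch.Tensor:
        return ik + 1.0

# Declaring axis 0 of "ik" symbolic lets the exported model accept any
# cache length without re-exporting.
torch.onnx.export(
    CachePassthrough(),
    (torch.zeros(12, 24, 512),),  # example cache: [len, layers, dim]
    "cache_passthrough.onnx",
    input_names=["ik"],
    output_names=["ok"],
    dynamic_axes={"ik": {0: "ik_length"}},
)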