modified: GPT_SoVITS/AR/models/t2s_model.py

This commit is contained in:
chasonjiang 2024-03-09 20:23:55 +08:00
parent 3b9259b0a1
commit be49e32505

View File

@ -115,7 +115,7 @@ class T2SBlock:
)
return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache, attn_mask : torch.Tensor):
def decode_next_token(self, x, k_cache, v_cache):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)