diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 23da380..4908c59 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -115,7 +115,7 @@ class T2SBlock: ) return x, k_cache, v_cache - def decode_next_token(self, x, k_cache, v_cache, attn_mask : torch.Tensor): + def decode_next_token(self, x, k_cache, v_cache): q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) k_cache = torch.cat([k_cache, k], dim=1)