modified: GPT_SoVITS/AR/models/t2s_model.py

This commit is contained in:
chasonjiang 2024-03-09 20:21:11 +08:00
parent 4096a17e7e
commit 3b9259b0a1

View File

@ -166,10 +166,10 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : torch.Tensor
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache
@ -554,7 +554,7 @@ class Text2SemanticDecoder(nn.Module):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]