mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
modified: GPT_SoVITS/AR/models/t2s_model.py
This commit is contained in:
parent
4096a17e7e
commit
3b9259b0a1
@ -166,10 +166,10 @@ class T2STransformer:
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : torch.Tensor
|
||||
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
|
||||
):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@ -554,7 +554,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
|
Loading…
x
Reference in New Issue
Block a user