diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 56dca94..74dca9d 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -39,7 +39,7 @@ default_config = {
     "EOS": 1024,
 }
 
-# @torch.jit.script
+@torch.jit.script
 # Efficient implementation equivalent to the following:
 def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
     B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
@@ -82,7 +82,7 @@ class T2SMLP:
         return x
 
 
-# @torch.jit.script
+@torch.jit.script
 class T2SBlock:
     def __init__(
         self,
@@ -114,6 +114,8 @@ class T2SBlock:
         self.norm_b2 = norm_b2
         self.norm_eps2 = norm_eps2
 
+        self.false = torch.tensor(False, dtype=torch.bool)
+
     @torch.jit.ignore
     def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
         if padding_mask is None:
@@ -150,7 +152,9 @@ class T2SBlock:
         if padding_mask is not None:
             for i in range(batch_size):
                 # mask = padding_mask[i,:,0]
-                idx = torch.where(padding_mask[i,:,0]==False)[0]
+                if self.false.device!= padding_mask.device:
+                    self.false = self.false.to(padding_mask.device)
+                idx = torch.where(padding_mask[i,:,0]==self.false)[0]
                 x_item = x[i,idx,:].unsqueeze(0)
                 attn_item = attn[i,idx,:].unsqueeze(0)
                 x_item = x_item + attn_item
@@ -218,7 +222,7 @@ class T2SBlock:
         return x, k_cache, v_cache
 
 
-# @torch.jit.script
+@torch.jit.script
 class T2STransformer:
     def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
         self.num_blocks : int = num_blocks
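
The diff re-enables `@torch.jit.script` on the attention helper and the `T2SBlock`/`T2STransformer` classes, and replaces the `padding_mask[i,:,0]==False` comparison with a comparison against a boolean tensor cached on the block (`self.false`), moved to the mask's device on demand. Below is a minimal, self-contained sketch of that cache-and-relocate pattern; the class and method names (`MaskedBlock`, `unpadded_indices`) are hypothetical and only illustrate the idea, they are not part of the PR.

```python
import torch


class MaskedBlock:
    """Sketch (hypothetical class): keep a scalar bool constant cached on the
    object and lazily move it to the padding mask's device, so the comparison
    stays a tensor-vs-tensor op."""

    def __init__(self):
        # Cached constant, created on CPU; relocated on first use per device.
        self.false = torch.tensor(False, dtype=torch.bool)

    def unpadded_indices(self, padding_mask: torch.Tensor, i: int) -> torch.Tensor:
        # Keep the cached constant on the same device as the mask.
        if self.false.device != padding_mask.device:
            self.false = self.false.to(padding_mask.device)
        # Indices of positions in row i that are not padded (mask value == False).
        return torch.where(padding_mask[i, :, 0] == self.false)[0]


if __name__ == "__main__":
    block = MaskedBlock()
    mask = torch.tensor([[[False], [False], [True]]])  # (batch=1, seq=3, 1)
    print(block.unpadded_indices(mask, 0))  # tensor([0, 1])
```

The same result could be obtained by comparing against the literal `False`, as the old line did; caching the constant as a tensor attribute is presumably a concession to scripting the surrounding block, and the device check avoids an implicit cross-device comparison once the model runs on GPU.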