diff --git a/GPT_SoVITS/module/attentions_onnx.py b/GPT_SoVITS/module/attentions_onnx.py
index df0ae82..bc63a06 100644
--- a/GPT_SoVITS/module/attentions_onnx.py
+++ b/GPT_SoVITS/module/attentions_onnx.py
@@ -188,38 +188,27 @@ class MultiHeadAttention(nn.Module):
         query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
         key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
         value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+
         if self.window_size is not None:
             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
+            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
             scores_local = self._relative_position_to_absolute_position(rel_logits)
             scores = scores + scores_local
+
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
-                scores = scores.masked_fill(block_mask == 0, -1e4)
+
         p_attn = F.softmax(scores, dim=-1)
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
+
         if self.window_size is not None:
             relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = (
-            output.transpose(2, 3).contiguous().view(b, d, -1)
-        )
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+
+        output = (output.transpose(2, 3).contiguous().view(b, d, -1))
         return output, p_attn
 
     def _matmul_with_relative_values(self, x, y):
@@ -243,16 +232,16 @@ class MultiHeadAttention(nn.Module):
     def _get_relative_embeddings(self, relative_embeddings, length):
         max_relative_position = 2 * self.window_size + 1
         # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
+        pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
+        pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
+        pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
+        slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
+
         slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
+        padded_relative_embeddings = F.pad(
+            relative_embeddings,
+            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+        )
         used_relative_embeddings = padded_relative_embeddings[
             :, slice_start_position:slice_end_position
         ]
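
Note on the second hunk: the old `_get_relative_embeddings` computed the pad and slice offsets with Python's built-in `max()` and only called `F.pad` behind `if pad_length > 0:`. During ONNX tracing such Python-level decisions are resolved once, for the example length used at export, so the result is either frozen into the graph as a constant or forces conditional ops. The rewrite keeps the arithmetic as int64 tensor ops and pads unconditionally (a zero-width pad is a no-op), in line with the existing comment "Pad first before slice to avoid using cond ops." A minimal sketch of the difference, using hypothetical helper names that are not in the repo:

```python
import torch

def pad_amount_python(length: int, window_size: int) -> int:
    # Python-level max() is evaluated once when the model is traced,
    # so the pad amount is baked into the exported graph as a constant.
    return max(length - (window_size + 1), 0)

def pad_amount_tensor(length: torch.Tensor, window_size: int) -> torch.Tensor:
    # torch.max over tensors stays in the graph, so the pad amount
    # follows the runtime sequence length instead of the export-time one.
    zero = torch.zeros((1,), dtype=torch.int64)
    return torch.max(zero + length - (window_size + 1), zero)

print(pad_amount_python(10, 4))                # 5
print(pad_amount_tensor(torch.tensor(10), 4))  # tensor([5])
```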