From 0b2e3760c268b0cbafb0f3e00207f776b80ac7aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:39:56 +0800 Subject: [PATCH] Add files via upload --- GPT_SoVITS/module/attentions_onnx.py | 47 +++++++++++----------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/GPT_SoVITS/module/attentions_onnx.py b/GPT_SoVITS/module/attentions_onnx.py index df0ae82..bc63a06 100644 --- a/GPT_SoVITS/module/attentions_onnx.py +++ b/GPT_SoVITS/module/attentions_onnx.py @@ -188,38 +188,27 @@ class MultiHeadAttention(nn.Module): query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) - scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys( - query / math.sqrt(self.k_channels), key_relative_embeddings - ) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) scores_local = self._relative_position_to_absolute_position(rel_logits) scores = scores + scores_local + if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) - if self.block_length is not None: - block_mask = ( - torch.ones_like(scores) - .triu(-self.block_length) - .tril(self.block_length) - ) - scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) p_attn = self.drop(p_attn) output = torch.matmul(p_attn, value) + if self.window_size is not None: relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_v, t_s - ) - output = output + self._matmul_with_relative_values( - relative_weights, value_relative_embeddings - ) - output = ( - output.transpose(2, 3).contiguous().view(b, d, -1) - ) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + + output = (output.transpose(2, 3).contiguous().view(b, d, -1)) return output, p_attn def _matmul_with_relative_values(self, x, y): @@ -243,16 +232,16 @@ class MultiHeadAttention(nn.Module): def _get_relative_embeddings(self, relative_embeddings, length): max_relative_position = 2 * self.window_size + 1 # Pad first before slice to avoid using cond ops. - pad_length = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) + pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1) + pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length + pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64)) + slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64)) + slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, - commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), - ) - else: - padded_relative_embeddings = relative_embeddings + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) used_relative_embeddings = padded_relative_embeddings[ :, slice_start_position:slice_end_position ]