Add files via upload

Ναρουσέ·μ·γιουμεμί·Χινακάννα 2024-02-08 21:39:56 +08:00 committed by GitHub
parent 08aed05796
commit 0b2e3760c2


@@ -188,38 +188,27 @@ class MultiHeadAttention(nn.Module):
         query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
         key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
         value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
         if self.window_size is not None:
             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
+            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
             scores_local = self._relative_position_to_absolute_position(rel_logits)
             scores = scores + scores_local
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
-                scores = scores.masked_fill(block_mask == 0, -1e4)
         p_attn = F.softmax(scores, dim=-1)
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
         if self.window_size is not None:
             relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = (
-            output.transpose(2, 3).contiguous().view(b, d, -1)
-        )
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+        output = (output.transpose(2, 3).contiguous().view(b, d, -1))
         return output, p_attn

     def _matmul_with_relative_values(self, x, y):
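
This hunk collapses the black-formatted multi-line calls back to single lines and removes the optional `block_length` branch, which applied a banded (local) attention mask on top of the padding mask. A minimal sketch of what the removed branch computed, with illustrative values not taken from the commit (`block_length = 2`, a 5x5 score matrix): `triu(-k).tril(k)` keeps a diagonal band of half-width `k`, and positions outside the band are pushed to -1e4 so the softmax effectively zeroes them.

import torch

block_length = 2            # hypothetical band half-width
scores = torch.randn(5, 5)  # stand-in for the attention logits
# ones where |i - j| <= block_length, zeros outside the band
block_mask = torch.ones_like(scores).triu(-block_length).tril(block_length)
masked = scores.masked_fill(block_mask == 0, -1e4)
print(block_mask)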
@@ -243,16 +232,16 @@ class MultiHeadAttention(nn.Module):
     def _get_relative_embeddings(self, relative_embeddings, length):
         max_relative_position = 2 * self.window_size + 1
         # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
+        pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
+        pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
+        pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
+        slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
         slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
+        padded_relative_embeddings = F.pad(
+            relative_embeddings,
+            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+        )
         used_relative_embeddings = padded_relative_embeddings[
             :, slice_start_position:slice_end_position
         ]
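
This second hunk follows through on the existing "pad first before slice to avoid using cond ops" comment: the Python-level `max()` calls and the `if pad_length > 0` branch become `torch.max` on int64 tensors plus an unconditional `F.pad`, so the clamping is an op inside the graph rather than a host-side decision frozen at trace time. A minimal sketch of the clamp-to-zero pattern, assuming `length` arrives as a 1-element int64 tensor while tracing; `window_size = 4` and the length value are illustrative only:

import torch

window_size = 4                                # hypothetical window size
length = torch.tensor([3], dtype=torch.int64)  # stand-in for t_s at trace time

pad_l = torch.zeros((1), dtype=torch.int64) + length - (window_size + 1)
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
print(pad_length)  # tensor([0]): the negative raw value is clamped with no Python branch

Since `F.pad` with zero padding amounts returns a tensor equal to its input, dropping the `if`/`else` around it is behavior-preserving while keeping the exported graph branch-free.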