from __future__ import annotations

import math
from abc import ABC, abstractmethod
from typing import MutableSequence, cast

import mlx.core as mx
import mlx.nn as nn

from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX

Array = mx.array


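# MLX building blocks for the text-to-semantic (T2S) autoregressive decoder:
# token/positional embeddings, plain and quantized KV caches, and the abstract
# attention / transformer / decoder classes that concrete MLX backends subclass.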
class TokenEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self):
        return self.word_embeddings.weight

    def embedding(self, index: int):
        return self.word_embeddings.weight[index : index + 1]

    def __call__(self, x: Array):
        x = self.word_embeddings(x)
        return x


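# Standard sinusoidal positional encoding, precomputed once for up to
# (max_batch_size, max_seq_len) positions:
#   PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim))
# The output is x * x_scale + alpha * PE. __call__ looks up a single position
# per sample (decode step), while prefill() adds the encoding to a whole
# [batch, seq_len, dim] prompt.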
class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        scale: bool = False,
        max_batch_size: int = 10,
        max_seq_len: int = 2000,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = mx.ones(1)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.reverse = False
        self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
        self.compute_pe()

    def compute_pe(self):
        """Reset the positional encodings."""

        if self.reverse:
            position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
        else:
            position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
        div_term = mx.exp(
            mx.arange(
                0,
                self.embedding_dim,
                2,
            )
            * -(math.log(10000.0) / self.embedding_dim)
        )
        pe = self._pe
        pe[:, :, 0::2] = mx.sin(position * div_term)
        pe[:, :, 1::2] = mx.cos(position * div_term)

    def __call__(self, input_pos: Array, x: Array):
        """
        Args:
            input_pos (Array): [batch_size, ]
            x (Array): [batch_size, 1, embed_dim]

        Returns:
            embedded_x (Array): [batch_size, 1, embed_dim]
        """

        batch_size = cast(tuple[int, ...], x.shape)[0]
        pe_values = self._pe[mx.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)

        return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: Array):
        """
        Args:
            x (Array): [batch_size, seq_len, embed_dim]

        Returns:
            embedded_x (Array): [batch_size, seq_len, embed_dim]
        """
        pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
        return x * self.x_scale + self.alpha * pe_values


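# Unquantized KV cache in "HND" layout: (batch, n_heads, max_seq_length, head_dim).
# update_cache() writes the current step's K/V at position input_pos - 1 (rows
# selected per sample via cache_idx); prefill_kv() copies a whole [B, S, H, D]
# prompt into the cache after swapping to the HND layout.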
class KVCacheHND(KVCacheProtocol):
    @staticmethod
    def empty(kv_cache):
        assert len(kv_cache) == 2
        k_cache, v_cache = kv_cache

        k_cache[:] = 0
        v_cache[:] = 0

    @staticmethod
    def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
        # input_pos: [B, ], k_val: [B, H, 1, D]
        assert len(kv_cache) == 2
        k_out, v_out = kv_cache
        ip0 = input_pos - 1

        k_out[cache_idx, :, ip0, None] = k_val
        v_out[cache_idx, :, ip0, None] = v_val

        return k_out, v_out

    @staticmethod
    def prefill_kv(k_val, v_val, kv_cache):
        # k_val: [B, S, H, D]
        assert len(kv_cache) == 2
        k_cache, v_cache = kv_cache

        k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
        v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)

    @staticmethod
    def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)

        return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))


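# KV cache stored with mx.quantize: each head_dim vector is packed into uint32
# words (32 // bits values per word) plus per-group scales and biases.
# For example (illustrative numbers, not taken from this repo's config):
# head_dim=64 with bits=8 gives 32 // 8 = 4 values per uint32, so
# 64 // 4 = 16 packed words per cached position, and group_size=32 gives
# 64 // 32 = 2 scale/bias pairs per head. The concrete attention subclasses are
# expected to dequantize (e.g. with mx.dequantize) or use quantized matmuls when
# reading this cache; that logic is not defined in this file.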
class KVCacheHNDQuantized(KVCacheProtocol):
    @staticmethod
    def _el_per_int(bits: int) -> int:
        return 32 // bits

    @staticmethod
    def _packed_dim(head_dim: int, bits: int = 8) -> int:
        el_per_int = KVCacheHNDQuantized._el_per_int(bits)
        if head_dim % el_per_int != 0:
            raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
        return head_dim // el_per_int

    @staticmethod
    def _group_count(head_dim: int, group_size: int = 32) -> int:
        assert group_size in {32, 64, 128}
        if head_dim % group_size != 0:
            raise ValueError(f"{head_dim} is not divisible by {group_size=}")
        return head_dim // group_size

    @staticmethod
    def empty(kv_cache) -> None:
        assert len(kv_cache) == 3
        (k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache

        k_q[:] = 0
        k_s[:] = 0
        k_b[:] = 0
        v_q[:] = 0
        v_s[:] = 0
        v_b[:] = 0

    @staticmethod
    def update_cache(
        input_pos,
        k_val,
        v_val,
        kv_cache,
        cache_idx,
    ):
        # input_pos: [B, ], k_val: [B, H, 1, D]

        assert len(kv_cache) == 3
        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache

        k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
        v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)

        ip0 = input_pos - 1

        k_q_out[cache_idx, :, ip0, None] = k_q
        k_s_out[cache_idx, :, ip0, None] = k_s
        k_b_out[cache_idx, :, ip0, None] = k_b

        v_q_out[cache_idx, :, ip0, None] = v_q
        v_s_out[cache_idx, :, ip0, None] = v_s
        v_b_out[cache_idx, :, ip0, None] = v_b

        return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)

    @staticmethod
    def prefill_kv(
        k_val,
        v_val,
        kv_cache,
    ) -> None:
        assert len(kv_cache) == 3
        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache

        S = cast(tuple[int, ...], k_val.shape)[1]

        k_sw = k_val.swapaxes(1, 2)
        v_sw = v_val.swapaxes(1, 2)

        k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
        v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)

        k_q_out[..., :S, :] = k_q
        k_s_out[..., :S, :] = k_s
        k_b_out[..., :S, :] = k_b

        v_q_out[..., :S, :] = v_q
        v_s_out[..., :S, :] = v_s
        v_b_out[..., :S, :] = v_b

    @staticmethod
    def init_cache(
        batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype: mx.Dtype,
        *,
        group_size: int = 32,
        bits: int = 8,
    ) -> KVCacheQ:
        packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
        group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)

        packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
        group_shape = (batch_size, n_heads, max_seq_length, group_cnt)

        k_q = mx.zeros(packed_shape, dtype=mx.uint32)
        k_s = mx.zeros(group_shape, dtype=dtype)
        k_b = mx.zeros(group_shape, dtype=dtype)

        v_q = mx.zeros(packed_shape, dtype=mx.uint32)
        v_s = mx.zeros(group_shape, dtype=dtype)
        v_b = mx.zeros(group_shape, dtype=dtype)

        return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)


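# Abstract multi-head attention. in_proj is a fused QKV projection
# (hidden_dim -> 3 * hidden_dim), scores are scaled by 1 / sqrt(head_dim), and
# prefill() runs mx.fast.scaled_dot_product_attention over the whole prompt
# while writing K/V into the cache via kc_class.prefill_kv. The single-step
# (decode) path is left to subclasses via the abstract __call__.
# nan_to_num() appears to guard against NaNs from fully masked rows in the
# padded batch (an inference; the guard is applied unconditionally here).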
class AttentionABC(ABC, nn.Module):
    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
        super().__init__()

        self.n_head = n_head
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head

        self.max_seq_length = max_seq_length

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.scale = 1 / math.sqrt(self.head_dim)

        self.kc_class: KVCacheProtocol

    @abstractmethod
    def __call__(
        self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
    ) -> Array: ...

    def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)

        q, k, v = self.in_proj(x).split(3, axis=-1)

        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))

        self.kc_class.prefill_kv(k, v, kv_cache)

        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))

        attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)

        attn = mx.nan_to_num(attn)

        attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)

        output = self.out_proj(attn)

        return output


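# Position-wise feed-forward block: Linear -> ReLU -> Linear
# (ffn_dim is 4 * hidden_dim where it is constructed below).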
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()

        self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=True)

    def __call__(self, x: Array):
        return self.linear2(nn.relu(self.linear1(x)))


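# Post-norm transformer block: LayerNorm is applied to (x + Attention(x)) and
# then to (h + FeedForward(h)), i.e. norm(x + sublayer(x)) rather than the
# pre-norm x + sublayer(norm(x)) arrangement.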
class TransformerBlockABC(nn.Module):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_seq_length = max_seq_length

        self.attention: AttentionABC

        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm(self.hidden_dim)
        self.ffn_norm = nn.LayerNorm(self.hidden_dim)

    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
        h = self.attention_norm(
            x
            + self.attention(
                x,
                input_pos,
                kv_cache,
                cache_idx,
                attn_mask,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))
        return out

    def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
        h = self.attention_norm(
            x
            + self.attention.prefill(
                x,
                kv_cache,
                attn_mask,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))

        return out


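# Stack of transformer blocks; __call__ (decode step) and prefill() thread one
# KV cache per layer through the matching block. self.layers is populated by
# concrete subclasses.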
class TransformerDecoderABC(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        n_layer: int,
        n_head: int,
        ffn_dim: int,
        vocab_size: int,
        max_seq_length: int,
        max_batch_size: int,
        *args,
        **kwds,
    ) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0

        self.head_dim = hidden_dim // n_head
        self.vocab_size = vocab_size

        self.n_layer = n_layer

        self.layers: MutableSequence[TransformerBlockABC]

        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

    def __call__(
        self,
        input_pos: Array,
        x: Array,
        kv_caches: MutableSequence[KVCache | KVCacheQ],
        cache_idx: Array,
        *args,
        **kwds,
    ):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer(
                x,
                input_pos,
                kv_cache,
                cache_idx,
                *args,
                **kwds,
            )

        return x

    def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.prefill(
                x,
                mask,
                kv_cache,
            )
        return x


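# Abstract T2S decoder: phoneme and semantic-token embeddings with sinusoidal
# positions, a projection for 1024-dim BERT features, the transformer stack
# (self.h, provided by subclasses) and a bias-free linear head over the
# semantic vocabulary. EOS is asserted to be the last id in that vocabulary.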
class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
    def __init__(
        self,
        config: dict,
        max_seq_length: int = 2000,
        max_batch_size: int = 10,
    ) -> None:
        super().__init__()

        hidden_dim: int = config["model"]["hidden_dim"]
        embedding_dim: int = config["model"]["embedding_dim"]
        n_head: int = config["model"]["head"]
        n_layer: int = config["model"]["n_layer"]
        vocab_size: int = config["model"]["vocab_size"]
        phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
        EOS: int = config["model"]["EOS"]
        ffn_dim: int = hidden_dim * 4

        self.n_layer = int(n_layer)
        self.hidden_dim = int(hidden_dim)
        self.n_head = int(n_head)
        assert hidden_dim % n_head == 0

        self.head_dim = int(hidden_dim // n_head)
        self.embedding_dim = int(embedding_dim)
        self.ffn_dim = int(ffn_dim)
        self.vocab_size = int(vocab_size)
        self.phoneme_vocab_size = int(phoneme_vocab_size)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.EOS = EOS
        assert self.EOS == self.vocab_size - 1

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC

        self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim,
            scale=False,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )
        self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim,
            scale=False,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )

        self.kv_class: KVCacheProtocol

    def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
        bsz = bsz or self.h.max_batch_size
        assert bsz <= self.h.max_batch_size
        seq_lens = self.h.max_seq_length
        dtype = self.bert_proj.bias.dtype
        cache: MutableSequence[KVCache | KVCacheQ] = [
            self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
            for _ in range(self.n_layer)
        ]
        mx.eval(cache)
        return cache

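    # Builds the prefill input: for every sample, the phoneme embedding (plus the
    # projected BERT features) is written at [0, x_len) and the semantic-token
    # embedding at [x_len, x_len + y_len), both with positional encodings applied,
    # into one zero-padded tensor of shape [bsz, max_x_len + y_len, embedding_dim].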
    def embed(
        self,
        x: list[Array],
        y: Array,
        bert_features: list[Array],
    ):
        x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
        x_len_max = max(x_len)
        xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
            bert_features[0].dtype
        )

        bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))

        y_len = cast(tuple[int, ...], y.shape)[1]
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position.prefill(y_emb)

        for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
            x_emb = self.ar_text_embedding(x_)
            bert = self.bert_proj(bert_feature)
            x_emb = x_emb + bert
            x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
            xy_pos[[bs], :len_] = x_pos
            xy_pos[[bs], len_ : len_ + y_len] = y_pos

        mx.eval(xy_pos)
        return xy_pos

    def compile(self):
        setattr(self.h, "__call__", mx.compile(self.h.__call__))
        # setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))

    def pre_forward(self, session: T2SSessionMLX):
        attn_mask = session.attn_mask
        return list(), dict(attn_mask=attn_mask)

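    # Maintains the attention mask between decode steps: on the first step
    # (idx == 0) it builds a [bsz, n_head, 1, max_seq_length] mask exposing each
    # sample's prefilled prompt, and on every step it marks the position indexed
    # by session.input_pos as visible for the next forward pass.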
    def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
        if idx == 0:
            prefill_len = session.prefill_len
            bsz = session.bsz

            range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
            prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
            attn_mask = range_tensor < prefill_len_expanded
            attn_mask = mx.repeat(attn_mask, self.n_head, 1)

            session.attn_mask = attn_mask

        attn_mask = session.attn_mask
        input_pos = session.input_pos
        attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
        mx.eval(attn_mask)