# GPT-SoVITS/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
# XXXXRT666 26d5eaf1b4 .
# 2025-09-08 19:30:35 +08:00
#
# 531 lines
# 16 KiB
# Python

from __future__ import annotations
import math
from abc import ABC, abstractmethod
from typing import MutableSequence, cast
import mlx.core as mx
import mlx.nn as nn
from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
# Shorthand for the MLX array type; used in the type annotations of this module.
Array = mx.array
class TokenEmbedding(nn.Module):
    """Thin wrapper around a single ``nn.Embedding`` lookup table."""

    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
    ):
        """
        Args:
            embedding_dim (int): width of each embedding vector
            vocab_size (int): number of rows in the lookup table
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

    @property
    def weight(self):
        """The weight array of the wrapped ``nn.Embedding``."""
        table = self.word_embeddings
        return table.weight

    def embedding(self, index: int):
        """Return row ``index`` of the table as a one-row slice (leading axis kept)."""
        stop = index + 1
        return self.word_embeddings.weight[index:stop]

    def __call__(self, x: Array):
        """Look up the embeddings for the token ids in ``x``."""
        return self.word_embeddings(x)
class SinePositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding backed by a precomputed table.

    The table ``_pe`` has shape ``(max_batch_size, max_seq_len, embedding_dim)``
    and is filled once in ``compute_pe``; every batch row holds the same values.
    """

    def __init__(
        self,
        embedding_dim: int,
        scale: bool = False,
        max_batch_size: int = 10,
        max_seq_len: int = 2000,
    ):
        """
        Args:
            embedding_dim (int): width of the embedding vectors
            scale (bool): when True the input is multiplied by sqrt(embedding_dim)
            max_batch_size (int): number of batch rows allocated in the table
            max_seq_len (int): number of positions allocated in the table
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        # Multiplier applied to the positional term in __call__ / prefill.
        self.alpha = mx.ones(1)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.reverse = False
        self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
        self.compute_pe()

    def compute_pe(self):
        """Reset the positional encodings."""
        if self.reverse:
            position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
        else:
            position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
        div_term = mx.exp(
            mx.arange(
                0,
                self.embedding_dim,
                2,
            )
            * -(math.log(10000.0) / self.embedding_dim)
        )
        # Slice assignment updates the array object held by self._pe in place.
        # position * div_term is (max_seq_len, embedding_dim / 2) and is written
        # to every batch row: sine on even channels, cosine on odd channels.
        pe = self._pe
        pe[:, :, 0::2] = mx.sin(position * div_term)
        pe[:, :, 1::2] = mx.cos(position * div_term)

    def __call__(self, input_pos: Array, x: Array):
        """
        Args:
            input_pos (Array): [batch_size, ]
            x (Array): [batch_size, 1, embed_dim]
        Returns:
            embedded_x (Array): [batch_size, 1, embed_dim]
        """
        batch_size = cast(tuple[int, ...], x.shape)[0]
        # Row b of the table is read at position input_pos[b] - 1.
        pe_values = self._pe[mx.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)
        return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: Array):
        """
        Args:
            x (Array): [batch_size, seq_len, embed_dim]
        Returns:
            embedded_x (Array): [batch_size, seq_len, embed_dim]
        """
        shape = cast(tuple[int, ...], x.shape)
        pe_values = self._pe[:, : shape[-2]]
        if len(shape) == 3:
            # Trim the batch axis of the table to the input's batch size. All
            # rows are identical, so this only changes the shape: without it
            # the result carries max_batch_size rows (or fails to broadcast
            # when batch_size is neither 1 nor max_batch_size).
            pe_values = pe_values[: shape[0]]
        return x * self.x_scale + self.alpha * pe_values
class KVCacheHND(KVCacheProtocol):
    """Static helpers for an unquantized key/value cache.

    A cache is a pair ``(k_cache, v_cache)``; ``init_cache`` allocates both with
    shape ``(batch_size, n_heads, max_seq_length, head_dim)``.
    """

    @staticmethod
    def empty(kv_cache):
        """Zero both cache arrays in place."""
        assert len(kv_cache) == 2
        keys, values = kv_cache
        for buffer in (keys, values):
            buffer[:] = 0

    @staticmethod
    def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
        """Write one new key/value step into the cache and return the cache arrays.

        Per the original annotations: input_pos is [B, ] and k_val is [B, H, 1, D].
        The step is stored at sequence index ``input_pos - 1`` of the rows
        selected by ``cache_idx``.
        """
        assert len(kv_cache) == 2
        keys, values = kv_cache
        slot = input_pos - 1
        keys[cache_idx, :, slot, None] = k_val
        values[cache_idx, :, slot, None] = v_val
        return keys, values

    @staticmethod
    def prefill_kv(k_val, v_val, kv_cache):
        """Copy a whole prompt's keys/values into the front of the cache.

        Per the original annotation k_val is [B, S, H, D]; axes 1 and 2 are
        swapped before the write so the data matches the cache layout.
        """
        assert len(kv_cache) == 2
        keys, values = kv_cache

        k_len = cast(tuple[int, ...], k_val.shape)[1]
        keys[..., :k_len, :] = k_val.swapaxes(1, 2)

        v_len = cast(tuple[int, ...], v_val.shape)[1]
        values[..., :v_len, :] = v_val.swapaxes(1, 2)

    @staticmethod
    def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
        """Allocate a zero-filled (k_cache, v_cache) pair of the given dtype."""
        shape = (batch_size, n_heads, max_seq_length, head_dim)
        k_cache = mx.zeros(shape, dtype=dtype)
        v_cache = mx.zeros(shape, dtype=dtype)
        return (k_cache, v_cache)
class KVCacheHNDQuantized(KVCacheProtocol):
    """Static helpers for a key/value cache stored in ``mx.quantize`` form.

    A cache is a 3-tuple ``((k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits))``:
    the three arrays returned by ``mx.quantize`` for keys and for values, plus
    the quantization settings used to produce them.
    """

    @staticmethod
    def _el_per_int(bits: int) -> int:
        """Number of ``bits``-wide elements that fit in one 32-bit word."""
        # NOTE(review): integer division — assumes bits divides 32 evenly; confirm
        # before using bit widths such as 3 or 6.
        return 32 // bits

    @staticmethod
    def _packed_dim(head_dim: int, bits: int = 8) -> int:
        """Length of the packed last axis for ``head_dim`` elements.

        Raises:
            ValueError: if head_dim is not a multiple of the elements per word.
        """
        el_per_int = KVCacheHNDQuantized._el_per_int(bits)
        if head_dim % el_per_int == 0:
            return head_dim // el_per_int
        raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")

    @staticmethod
    def _group_count(head_dim: int, group_size: int = 32) -> int:
        """Number of quantization groups along the head dimension.

        Raises:
            ValueError: if head_dim is not a multiple of group_size.
        """
        assert group_size in {32, 64, 128}
        if head_dim % group_size == 0:
            return head_dim // group_size
        raise ValueError(f"{head_dim} is not divisible by {group_size=}")

    @staticmethod
    def empty(kv_cache) -> None:
        """Zero all six cache arrays in place; the settings tuple is untouched."""
        assert len(kv_cache) == 3
        (k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
        for buffer in (k_q, k_s, k_b, v_q, v_s, v_b):
            buffer[:] = 0

    @staticmethod
    def update_cache(
        input_pos,
        k_val,
        v_val,
        kv_cache,
        cache_idx,
    ):
        """Quantize one new key/value step, store it, and return the cache.

        Per the original annotations: input_pos is [B, ] and k_val is [B, H, 1, D].
        The step is stored at sequence index ``input_pos - 1`` of the rows
        selected by ``cache_idx``.
        """
        assert len(kv_cache) == 3
        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache

        k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
        v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)

        slot = input_pos - 1
        writes = (
            (k_q_out, k_q),
            (k_s_out, k_s),
            (k_b_out, k_b),
            (v_q_out, v_q),
            (v_s_out, v_s),
            (v_b_out, v_b),
        )
        for target, piece in writes:
            target[cache_idx, :, slot, None] = piece

        return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)

    @staticmethod
    def prefill_kv(
        k_val,
        v_val,
        kv_cache,
    ) -> None:
        """Quantize a whole prompt's keys/values and copy them into the cache front.

        Axes 1 and 2 of the inputs are swapped before quantization; the first
        ``k_val.shape[1]`` positions of every cache array are overwritten.
        """
        assert len(kv_cache) == 3
        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache

        S = cast(tuple[int, ...], k_val.shape)[1]

        k_q, k_s, k_b = mx.quantize(k_val.swapaxes(1, 2), group_size=group_size, bits=bits)
        v_q, v_s, v_b = mx.quantize(v_val.swapaxes(1, 2), group_size=group_size, bits=bits)

        writes = (
            (k_q_out, k_q),
            (k_s_out, k_s),
            (k_b_out, k_b),
            (v_q_out, v_q),
            (v_s_out, v_s),
            (v_b_out, v_b),
        )
        for target, piece in writes:
            target[..., :S, :] = piece

    @staticmethod
    def init_cache(
        batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype: mx.Dtype,
        *,
        group_size: int = 32,
        bits: int = 8,
    ) -> KVCacheQ:
        """Allocate a zero-filled quantized cache.

        The packed arrays are uint32 with last axis ``_packed_dim(head_dim, bits)``;
        the other two arrays of each triple use ``dtype`` with last axis
        ``_group_count(head_dim, group_size)``.

        Raises:
            ValueError: if head_dim is incompatible with bits or group_size.
        """
        packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
        group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)

        packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
        group_shape = (batch_size, n_heads, max_seq_length, group_cnt)

        def _zero_triple():
            # One (packed, per-group, per-group) set of buffers.
            return (
                mx.zeros(packed_shape, dtype=mx.uint32),
                mx.zeros(group_shape, dtype=dtype),
                mx.zeros(group_shape, dtype=dtype),
            )

        key_buffers = _zero_triple()
        value_buffers = _zero_triple()
        return key_buffers, value_buffers, (group_size, bits)
class AttentionABC(ABC, nn.Module):
    """Base class for multi-head self-attention with a KV cache.

    Subclasses implement ``__call__`` (the single-step decode path) and are
    expected to set ``kc_class``, the cache helper used by ``prefill``.
    """

    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
        """
        Args:
            n_head (int): number of attention heads; must divide hidden_dim
            hidden_dim (int): model width
            max_seq_length (int): stored on the instance for subclasses
        """
        super().__init__()

        self.n_head = n_head
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head

        self.max_seq_length = max_seq_length

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.scale = 1 / math.sqrt(self.head_dim)

        # Declared only; not assigned in this base class.
        self.kc_class: KVCacheProtocol

    @abstractmethod
    def __call__(
        self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
    ) -> Array: ...

    def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
        """Run attention over a whole prompt and fill the KV cache.

        Args:
            x (Array): [batch_size, seq_len, hidden_dim]
            kv_cache: cache passed to ``kc_class.prefill_kv``
            attn_mask (Array): mask forwarded to scaled_dot_product_attention
        Returns:
            output (Array): [batch_size, seq_len, hidden_dim]
        """
        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)

        q, k, v = self.in_proj(x).split(3, axis=-1)

        # [B, S, hidden] -> [B, S, H, D]
        q, k, v = (t.reshape(bsz, seqlen, self.n_head, self.head_dim) for t in (q, k, v))

        # The cache helper receives the [B, S, H, D] layout.
        self.kc_class.prefill_kv(k, v, kv_cache)

        # [B, S, H, D] -> [B, H, S, D] for the attention kernel.
        q, k, v = (t.swapaxes(1, 2) for t in (q, k, v))

        attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)

        # Replace any NaNs in the attention output (presumably from fully masked rows).
        attn = mx.nan_to_num(attn)

        # Restore [B, S, hidden]. Use the real batch size rather than a fixed 1
        # so batched prompts are not folded into the sequence axis.
        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)

        return self.out_proj(attn)
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: linear -> ReLU -> linear."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        """
        Args:
            dim (int): input and output width
            hidden_dim (int): width of the intermediate layer
        """
        super().__init__()
        self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=True)

    def __call__(self, x: Array):
        """Apply the MLP to ``x``."""
        expanded = self.linear1(x)
        activated = nn.relu(expanded)
        return self.linear2(activated)
class TransformerBlockABC(nn.Module):
    """Base transformer block: attention and feed-forward, each followed by a
    LayerNorm applied to the residual sum.

    ``attention`` is only declared here; it is not assigned in this base class.
    """

    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
        """
        Args:
            n_head (int): not used directly by this base class
            ffn_dim (int): intermediate width of the feed-forward network
            hidden_dim (int): model width
            max_seq_length (int): stored on the instance
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_seq_length = max_seq_length

        self.attention: AttentionABC

        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm(self.hidden_dim)
        self.ffn_norm = nn.LayerNorm(self.hidden_dim)

    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
        """Single decode step through the block."""
        attn_out = self.attention(x, input_pos, kv_cache, cache_idx, attn_mask)
        h = self.attention_norm(x + attn_out)

        ffn_out = self.feed_forward(h)
        return self.ffn_norm(h + ffn_out)

    def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
        """Whole-prompt pass through the block, using the attention's ``prefill``."""
        # Note the argument order: attention.prefill takes (x, kv_cache, attn_mask).
        attn_out = self.attention.prefill(x, kv_cache, attn_mask)
        h = self.attention_norm(x + attn_out)

        ffn_out = self.feed_forward(h)
        return self.ffn_norm(h + ffn_out)
class TransformerDecoderABC(nn.Module):
    """Base class for the stack of transformer blocks.

    ``layers`` is only declared here; it is not assigned in this base class.
    """

    def __init__(
        self,
        hidden_dim: int,
        n_layer: int,
        n_head: int,
        ffn_dim: int,
        vocab_size: int,
        max_seq_length: int,
        max_batch_size: int,
        *args,
        **kwds,
    ) -> None:
        """Record the model dimensions; ``ffn_dim`` is accepted but not stored."""
        super().__init__()

        assert hidden_dim % n_head == 0

        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.head_dim = hidden_dim // n_head
        self.vocab_size = vocab_size
        self.n_layer = n_layer

        self.layers: MutableSequence[TransformerBlockABC]

        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

    def __call__(
        self,
        input_pos: Array,
        x: Array,
        kv_caches: MutableSequence[KVCache | KVCacheQ],
        cache_idx: Array,
        *args,
        **kwds,
    ):
        """Run one decode step through every layer, pairing layer i with cache i.

        Extra positional/keyword arguments are forwarded unchanged to each layer.
        """
        hidden = x
        for block, block_cache in zip(self.layers, kv_caches):
            hidden = block(hidden, input_pos, block_cache, cache_idx, *args, **kwds)
        return hidden

    def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
        """Run the whole prompt through every layer's ``prefill``."""
        hidden = x
        for block, block_cache in zip(self.layers, kv_caches):
            hidden = block.prefill(hidden, mask, block_cache)
        return hidden
class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
    """Base class of the text-to-semantic decoder.

    Holds the embeddings, positional encodings, BERT projection and output
    head. The transformer stack ``h`` and the cache helper ``kv_class`` are
    only declared here; they are not assigned in this base class.
    """

    def __init__(
        self,
        config: dict,
        max_seq_length: int = 2000,
        max_batch_size: int = 10,
    ) -> None:
        """
        Args:
            config (dict): must provide ``config["model"]`` with the keys
                hidden_dim, embedding_dim, head, n_layer, vocab_size,
                phoneme_vocab_size and EOS
            max_seq_length (int): capacity of the positional tables
            max_batch_size (int): batch capacity of the positional tables
        """
        super().__init__()

        model_cfg = config["model"]
        hidden_dim: int = model_cfg["hidden_dim"]
        embedding_dim: int = model_cfg["embedding_dim"]
        n_head: int = model_cfg["head"]
        n_layer: int = model_cfg["n_layer"]
        vocab_size: int = model_cfg["vocab_size"]
        phoneme_vocab_size: int = model_cfg["phoneme_vocab_size"]
        EOS: int = model_cfg["EOS"]
        # The feed-forward width is fixed at four times the model width.
        ffn_dim: int = hidden_dim * 4

        self.n_layer = int(n_layer)
        self.hidden_dim = int(hidden_dim)
        self.n_head = int(n_head)
        assert hidden_dim % n_head == 0
        self.head_dim = int(hidden_dim // n_head)
        self.embedding_dim = int(embedding_dim)
        self.ffn_dim = int(ffn_dim)
        self.vocab_size = int(vocab_size)
        self.phoneme_vocab_size = int(phoneme_vocab_size)

        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

        # EOS is required to be the last id of the vocabulary.
        self.EOS = EOS
        assert self.EOS == self.vocab_size - 1

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC

        # Text and audio streams use separately constructed, identically
        # configured positional embeddings.
        position_options = dict(
            scale=False,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )

        self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
        self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, **position_options)

        self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
        self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, **position_options)

        self.kv_class: KVCacheProtocol

    def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
        """Allocate one KV cache per layer.

        Args:
            bsz (int): batch size; 0 selects ``self.h.max_batch_size``
            *args, **kwds: forwarded to ``kv_class.init_cache``
        Returns:
            list with ``n_layer`` caches, already evaluated
        """
        if not bsz:
            bsz = self.h.max_batch_size
        assert bsz <= self.h.max_batch_size

        seq_lens = self.h.max_seq_length
        # The cache dtype follows the dtype of the BERT projection bias.
        dtype = self.bert_proj.bias.dtype

        cache: MutableSequence[KVCache | KVCacheQ] = []
        for _ in range(self.n_layer):
            layer_cache = self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
            cache.append(layer_cache)

        mx.eval(cache)
        return cache

    def embed(
        self,
        x: list[Array],
        y: Array,
        bert_features: list[Array],
    ):
        """Build the prefill input by placing text and audio embeddings side by side.

        Args:
            x (list[Array]): one phoneme-id array per sample; lengths may differ
            y (Array): audio token ids; axis 1 is the audio length
            bert_features (list[Array]): one feature array per sample; axes 0
                and 1 are swapped before projection
        Returns:
            xy_pos (Array): [len(x), max text length + audio length, embedding_dim]
        """
        x_len: list[int] = [cast(tuple[int, ...], tokens.shape)[0] for tokens in x]
        x_len_max = max(x_len)
        y_len = cast(tuple[int, ...], y.shape)[1]

        # Zero buffer in the dtype of the BERT features; positions beyond
        # len_ + y_len of a sample stay zero.
        total_len = x_len_max + y_len
        xy_pos = mx.zeros((len(x), total_len, self.embedding_dim)).astype(bert_features[0].dtype)

        bert_features = [feature.swapaxes(0, 1) for feature in bert_features]

        # The audio prompt embedding is computed once and reused for each sample.
        y_pos = self.ar_audio_position.prefill(self.ar_audio_embedding(y))

        for bs, (tokens, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
            text_emb = self.ar_text_embedding(tokens) + self.bert_proj(bert_feature)
            x_pos = self.ar_text_position.prefill(mx.expand_dims(text_emb, 0))

            # Text first, then the audio prompt directly after this sample's text.
            xy_pos[[bs], :len_] = x_pos
            xy_pos[[bs], len_ : len_ + y_len] = y_pos

        mx.eval(xy_pos)
        return xy_pos

    def compile(self):
        """Wrap the decode step of ``self.h`` with ``mx.compile``."""
        # NOTE(review): this stores the compiled function as an instance
        # attribute named "__call__". Python resolves ``self.h(...)`` through the
        # type, so presumably only explicit ``self.h.__call__(...)`` callers get
        # the compiled version — confirm against the call sites.
        compiled_step = mx.compile(self.h.__call__)
        setattr(self.h, "__call__", compiled_step)
        # setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))

    def pre_forward(self, session: T2SSessionMLX):
        """Return the extra (args, kwargs) for a decode step: the session's attention mask."""
        return [], {"attn_mask": session.attn_mask}

    def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
        """Update the session's attention mask after a step.

        On the first step (``idx == 0``) a mask of shape
        [bsz, n_head, 1, max_seq_length] is built that is True for positions
        below each sample's prefill length. On every step the position
        ``session.input_pos`` of each sample is then set to True.
        """
        if idx == 0:
            prefill_len = session.prefill_len
            bsz = session.bsz

            positions = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
            lengths = prefill_len.reshape(bsz, 1, 1, 1)
            prompt_mask = mx.repeat(positions < lengths, self.n_head, 1)
            session.attn_mask = prompt_mask

        attn_mask = session.attn_mask
        input_pos = session.input_pos
        # In-place update of the mask array held by the session.
        attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
        mx.eval(attn_mask)