"""
|
|
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import List, MutableSequence, Protocol, TypeAlias, cast
|
|
|
|
import mlx.core as mx
|
|
import torch
|
|
|
|
from ..PyTorch.structs import T2SRequest
|
|
from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
|
|
|
|
Tensor = torch.Tensor
|
|
Array = mx.array
|
|
|
|
|
|

@dataclass(slots=True)
class T2SRequestMLX:
    x: List[Array]
    x_lens: Array
    prompts: Array
    bert_feature: List[Array]
    valid_length: int
    top_k: int = 5
    top_p: float = 1
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35

    @classmethod
    def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
        x = [mx.array(tensor.cpu()) for tensor in request.x]
        x_lens = mx.array(request.x_lens.cpu())
        prompts = mx.array(request.prompts.cpu())
        bert_feature = [mx.array(tensor.cpu()) for tensor in request.bert_feature]

        return cls(
            x,
            x_lens,
            prompts,
            bert_feature,
            request.valid_length,
            request.top_k,
            request.top_p,
            request.early_stop_num,
            request.temperature,
            request.repetition_penalty,
        )
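
# Usage sketch (illustrative, not part of the pipeline): given an already-populated
# PyTorch-side ``T2SRequest`` (called ``t2s_request`` here only for the example), the
# MLX copy is built once per request:
#
#     request_mlx = T2SRequestMLX.from_torch(t2s_request)
#
# Every tensor is moved to the CPU first so that it can be converted to an ``mx.array``.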


KVCache: TypeAlias = tuple[Array, Array]  # plain cache: (keys, values)
# Presumably the quantized cache: (quantized values, scales, biases) for the keys and
# for the values, plus the (group_size, bits) used by the quantization scheme.
KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]


class KVCacheProtocol(Protocol):
    @staticmethod
    def empty(kv_cache: KVCache | KVCacheQ) -> None: ...

    @staticmethod
    def update_cache(
        input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
    ) -> KVCache | KVCacheQ: ...

    @staticmethod
    def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...

    @staticmethod
    def init_cache(
        batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
    ) -> KVCache | KVCacheQ: ...
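
# A minimal sketch, not the decoder's real cache implementation: assuming a plain
# full-precision cache laid out as (batch, n_heads, max_seq_length, head_dim), an
# ``init_cache`` that satisfies the shape contract of ``KVCacheProtocol`` could look
# like this. ``prefill_kv``/``update_cache`` would then write keys and values into
# the returned arrays.
def _example_init_cache(
    batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype
) -> KVCache:
    shape = (batch_size, n_heads, max_seq_length, head_dim)
    return mx.zeros(shape, dtype=dtype), mx.zeros(shape, dtype=dtype)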


class T2SDecoderProtocol(Protocol):
    max_seq_length: int
    EOS: int
    n_head: int

    def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...


class T2SSessionMLX:
    def __init__(
        self,
        decoder: T2SDecoderProtocol,
        request_torch: T2SRequest,
        sample_func: type[SampleProtocolMLX] = sample_naive,
        device: mx.Device = mx.Device(mx.cpu),
        dtype: mx.Dtype = mx.float32,
    ):
        with mx.stream(device):
            request = T2SRequestMLX.from_torch(request_torch)

            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype

            bsz = len(request.x)
            y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
            self.bsz = bsz
            self.y_len = y_len

            # Cache
            self.kv_cache: MutableSequence[KVCache | KVCacheQ]
            self.sample = sample_func()

            # Forward args
            self.x = [i.astype(mx.int32) for i in request.x]
            self.x_lens = request.x_lens.astype(mx.int32)
            self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
            self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
            self.bert_feature = [i.astype(dtype) for i in request.bert_feature]

            self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]

            self.input_pos = mx.zeros_like(self.prefill_len)
            self.input_pos += self.prefill_len

            # EOS
            self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
            self.y_results: List[Array] = [None] * len(self.x)  # type: ignore

            self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)

            max_len = int(self.prefill_len.max(-1))
            attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)

            for bs in range(bsz):
                pos = int(self.x_lens[bs])
                seq_len = pos + y_len

                # Every position of this sample may attend to its full text prefix.
                attn_mask[bs, :seq_len, :pos] = True

                # Causal (lower-triangular) mask over the semantic-prompt region.
                ar_mask = ~mx.triu(
                    x=mx.ones(shape=(y_len, y_len), dtype=mx.bool_),
                    k=1,
                )
                attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask

            # Broadcast the per-sample mask across heads: (bsz, n_head, max_len, max_len).
            attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
            self.attn_mask = attn_mask

            mx.eval(self.attn_mask)
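
# Worked example of the mask built above (illustrative): for one sample with a text
# prefix of x_len = 2 positions (t0, t1) and a semantic prompt of y_len = 3 positions
# (s0, s1, s2), the per-head boolean mask is
#
#     query\key   t0  t1  s0  s1  s2
#     t0           1   1   .   .   .
#     t1           1   1   .   .   .
#     s0           1   1   1   .   .
#     s1           1   1   1   1   .
#     s2           1   1   1   1   1
#
# i.e. every query may attend to the whole text prefix, while attention inside the
# semantic-prompt region is causal.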