GPT-SoVITS/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
2025-09-06 22:58:58 +08:00

66 lines
2.2 KiB
Python

from typing import Protocol, cast
import mlx.core as mx
# Short alias for the MLX array type, used in the annotations below.
Array = mx.array


class SampleProtocolMLX(Protocol):
    """Structural interface that every MLX sampler in this module satisfies.

    Any class exposing a static ``__call__`` with the signature below is
    accepted wherever a ``SampleProtocolMLX`` is expected.
    """

    @staticmethod
    def __call__(
        logits: Array,
        previous_tokens: Array,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> Array:
        """Choose the next token id(s) given ``logits`` and the sampling settings."""
        ...
class sample_naive(SampleProtocolMLX):
    """Straightforward MLX sampler for a single decoding step.

    Steps, in order:
      1. Greedy shortcut when ``temperature <= 1e-5``.
      2. Repetition penalty on every id in ``previous_tokens``.
      3. Nucleus (top-p) filtering.
      4. Temperature scaling (only applied when ``temperature < 1.0``).
      5. Top-k filtering.
      6. Sampling with the Gumbel-max trick.
    """

    # @partial(mx.compile)
    @staticmethod
    def __call__(
        logits,
        previous_tokens,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
    ):
        """Sample one token id per row of ``logits``.

        Args:
            logits: Next-token scores indexed as ``logits[row, token]``
                (presumably ``(batch, vocab)`` — TODO confirm against caller).
                NOTE(review): when ``repetition_penalty != 1.0`` this array is
                updated via item assignment, which may be visible to the caller.
            previous_tokens: Already-generated token ids indexed as
                ``previous_tokens[row, t]``; one row per row of ``logits``.
            temperature: ``<= 1e-5`` selects greedy decoding. Only values
                ``< 1.0`` rescale the logits (see note below).
            top_k: Number of highest-scoring tokens to keep. Values ``<= 0``
                or ``>= vocab size`` disable top-k filtering.
            top_p: Nucleus threshold; values ``>= 1.0`` disable top-p filtering.
            repetition_penalty: ``1.0`` disables the penalty.

        Returns:
            ``int32`` array of chosen token ids; for 2-D ``logits`` its shape is
            ``(batch, 1)``.
        """
        if temperature <= 1e-5:
            # softmax is monotonic, so the argmax of the raw logits is the same
            # token; skipping it avoids a needless full-vocab pass.
            return mx.argmax(logits, axis=-1, keepdims=True).astype(mx.int32)

        if repetition_penalty != 1.0:
            batch_size = cast(tuple[int, ...], previous_tokens.shape)[0]
            # Column vector so that row b pairs with every id in previous_tokens[b].
            # A flat (batch,) arange would broadcast against the token axis instead
            # and fail (or silently misindex) whenever batch > 1.
            batch_idx = mx.arange(batch_size)[:, None]
            previous_tokens = previous_tokens.astype(mx.int64)
            selected_logits = logits[batch_idx, previous_tokens]
            # For penalty > 1 this lowers every previously seen token's score:
            # negative scores are amplified, positive scores are shrunk.
            selected_logits = mx.where(
                selected_logits < 0,
                selected_logits * repetition_penalty,
                selected_logits / repetition_penalty,
            )
            # Duplicate ids receive identical values, so repeated writes agree.
            logits[batch_idx, previous_tokens] = selected_logits

        if top_p < 1.0:
            # Descending order: column 0 holds the most likely token.
            sorted_indices = mx.argsort(-logits, axis=-1)
            sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
            cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
            sorted_indices_to_remove = cum_probs > top_p
            # Always keep the most likely token so at least one candidate survives.
            # (Keeping column -1 would preserve the *least* likely token instead.)
            sorted_indices_to_remove[:, 0] = False
            # Scatter the mask from sorted order back to vocabulary order.
            indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
            batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
            indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
            logits = mx.where(indices_to_remove, -mx.inf, logits)

        # NOTE(review): temperatures >= 1.0 are not applied (treated as 1.0).
        # Kept as-is for compatibility — confirm whether this is intended.
        if temperature < 1.0:
            logits = logits / temperature

        vocab_size = cast(tuple[int, ...], logits.shape)[-1]
        # top_k >= vocab_size would keep everything anyway; top_k <= 0 means "off".
        if 0 < top_k < vocab_size:
            # mx.topk does not guarantee the order of the k values it returns,
            # so the k-th largest score is their minimum, not a fixed position.
            top_values = mx.topk(logits, top_k, axis=-1)
            pivot = mx.min(top_values, axis=-1, keepdims=True)
            logits = mx.where(logits < pivot, -mx.inf, logits)

        # Gumbel-max trick: argmax(logits + Gumbel noise) is a draw from
        # softmax(logits); masked (-inf) entries cannot win while any finite
        # entry remains.
        gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
        idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
        return idx_next