From c1a4ff476c7050535db2289bee58ef0ca392f7ae Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 19 Oct 2025 21:51:54 +0100 Subject: [PATCH] . --- .github/build_windows_packages.ps1 | 4 +- .gitignore | 1 + Docker/miniconda_install.sh | 4 +- GPT_SoVITS/Accelerate/MLX/__init__.py | 6 +- .../Accelerate/MLX/backends/mlx_quantized.py | 179 ---- .../Accelerate/MLX/backends/mlx_static.py | 17 +- .../Accelerate/MLX/backends/mlx_varlen.py | 9 +- GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py | 93 +- GPT_SoVITS/Accelerate/MLX/structs_mlx.py | 18 +- GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py | 206 ++-- GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py | 224 ++--- GPT_SoVITS/Accelerate/PyTorch/__init__.py | 35 +- .../backends/flash_attn_varlen_cuda_graph.py | 2 +- .../PyTorch/backends/mps_flash_attn_varlen.py | 2 +- .../backends/sage_attn_varlen_cuda_graph.py | 2 +- .../backends/torch_static_cuda_graph.py | 2 +- .../PyTorch/backends/torch_varlen.py | 2 +- GPT_SoVITS/Accelerate/PyTorch/quantization.py | 158 ++++ GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py | 115 ++- GPT_SoVITS/Accelerate/PyTorch/structs.py | 4 +- GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py | 113 +-- .../Accelerate/PyTorch/t2s_model_abc.py | 56 +- GPT_SoVITS/Accelerate/__init__.py | 13 +- GPT_SoVITS/Accelerate/logger.py | 46 + GPT_SoVITS/inference_webui.py | 24 +- GPT_SoVITS/text/chinese2.py | 4 +- README.md | 4 +- docs/cn/README.md | 4 +- docs/ja/README.md | 4 +- docs/ko/README.md | 4 +- docs/tr/README.md | 4 +- install.ps1 | 6 +- install.sh | 10 +- requirements.txt | 1 + test.py | 879 ++++++++++++++++++ webui.py | 51 +- 36 files changed, 1660 insertions(+), 646 deletions(-) delete mode 100644 GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py create mode 100644 GPT_SoVITS/Accelerate/PyTorch/quantization.py create mode 100644 test.py diff --git a/.github/build_windows_packages.ps1 b/.github/build_windows_packages.ps1 index 606e6089..3a967ef0 100644 --- 
a/.github/build_windows_packages.ps1 +++ b/.github/build_windows_packages.ps1 @@ -157,12 +157,12 @@ Write-Host "[INFO] Installing PyTorch..." switch ($cuda) { "cu126" { & ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location --no-cache-dir - & ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location --no-cache-dir + & ".\runtime\python.exe" -m pip install torch torchao --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location --no-cache-dir & ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation --no-cache-dir } "cu128" { & ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location --no-cache-dir - & ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location --no-cache-dir + & ".\runtime\python.exe" -m pip install torch torchao --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location --no-cache-dir & ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation --no-cache-dir } default { diff --git a/.gitignore b/.gitignore index d9e61e9f..2c8dfe98 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ tools/AP_BWE/24kto48k/* !tools/AP_BWE/24kto48k/readme.txt onnx_export compile_cache +profiler # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/Docker/miniconda_install.sh b/Docker/miniconda_install.sh index cf2e3d6f..7873116c 100644 --- a/Docker/miniconda_install.sh +++ b/Docker/miniconda_install.sh @@ -60,10 +60,10 @@ source "$HOME/.bashrc" "$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y if [ "$CUDA_VERSION" = "12.8" ]; then - "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url 
https://download.pytorch.org/whl/cu128 + "$HOME/miniconda3/bin/pip" install torch torchao --no-cache-dir --index-url https://download.pytorch.org/whl/cu128 "$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia elif [ "$CUDA_VERSION" = "12.6" ]; then - "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126 + "$HOME/miniconda3/bin/pip" install torch torchao --no-cache-dir --index-url https://download.pytorch.org/whl/cu126 "$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia fi diff --git a/GPT_SoVITS/Accelerate/MLX/__init__.py b/GPT_SoVITS/Accelerate/MLX/__init__.py index 20691fcf..f98c8c40 100644 --- a/GPT_SoVITS/Accelerate/MLX/__init__.py +++ b/GPT_SoVITS/Accelerate/MLX/__init__.py @@ -5,8 +5,10 @@ if importlib.util.find_spec("mlx") is not None and platform.system() == "Darwin" from .sample_funcs_mlx import sample_naive as sample_naive_mlx from .t2s_engine_mlx import T2SEngine as T2SEngineMLX - backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"] + backends = ["mlx_static", "mlx_varlen"] else: backends = [] -__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"] +quantization_methods_mlx = [None, "MXFP4", "Affine"] + +__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends", "quantization_methods_mlx"] diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py deleted file mode 100644 index 54ceddb5..00000000 --- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py +++ /dev/null @@ -1,179 +0,0 @@ -from __future__ import annotations - -import mlx.core as mx -import mlx.nn as nn - -from ..structs_mlx import KVCacheQ -from ..t2s_model_abc import ( - AttentionABC, - KVCache, - KVCacheHND, - T2SDecoderABC, - TransformerBlockABC, - TransformerDecoderABC, -) - -Array = mx.array - - -class Attention(AttentionABC): - def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int): - 
super().__init__(n_head, hidden_dim, max_seq_length) - self.kc_class = KVCacheHND - - @staticmethod - def quantized_scaled_dot_product_attention( - queries: Array, - q_keys: tuple[Array, Array, Array], - q_values: tuple[Array, Array, Array], - scale: float, - mask: Array, - group_size: int = 32, - bits: int = 8, - ) -> Array: - queries *= scale - - scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits) - scores = mx.where(mask, scores, -mx.inf) - scores = mx.softmax(scores, axis=-1, precise=True) # type: ignore - out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits) - - return out - - def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array): - bsz, seqlen, _ = x.shape - - q, k, v = self.in_proj(x).split(3, axis=-1) - - q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v)) - - q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v)) - - kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx) - assert len(kv_cache) == 2 - - max_idx = int(input_pos.max()) - - q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache)) - - mask = attn_mask[..., :max_idx] - - attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) - - attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim) - - attn = self.out_proj(attn) - - return attn - - # def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array): - # bsz, seqlen, _ = x.shape - - # q, k, v = self.in_proj(x).split(3, axis=-1) - - # q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v)) - - # q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v)) - - # kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx) - - # assert len(kv_cache) == 3 - # (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache - - # 
k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b)) - - # mask = attn_mask[..., : int(input_pos.max())] - - # attn = Attention.quantized_scaled_dot_product_attention( - # q, - # (k_q, k_s, k_b), - # (v_q, v_s, v_b), - # self.scale, - # mask, - # group_size, - # bits, - # ) - - # attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim) - - # output = self.out_proj(attn) - - # return output - - -class TransformerBlock(TransformerBlockABC): - def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None: - super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds) - - self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds) - - -class TransformerDecoder(TransformerDecoderABC): - def __init__( - self, - hidden_dim: int, - n_layer: int, - n_head: int, - ffn_dim: int, - vocab_size: int, - max_seq_length: int, - max_batch_size: int, - *args, - **kwds, - ) -> None: - super().__init__( - hidden_dim, - n_layer, - n_head, - ffn_dim, - vocab_size, - max_seq_length, - max_batch_size, - *args, - **kwds, - ) - - self.layers = [ - TransformerBlock( - n_head, - ffn_dim, - hidden_dim, - max_seq_length, - *args, - **kwds, - ) - for _ in range(n_layer) - ] - - -class T2SDecoder(T2SDecoderABC): - def __init__( - self, - config: dict, - max_seq_length: int = 2000, - max_batch_size: int = 10, - ) -> None: - super().__init__(config, max_seq_length, max_batch_size) - - self.h = TransformerDecoder( - self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size - ) - - self.kv_class = KVCacheHND - self.group_size = 32 - self.bits = 8 - self.mode = "affine" - - def set_mode(self, mode: str): - assert mode in ["affine", "mxfp4"] - self.mode = mode - if self.mode == "mxfp4": - self.bits = 4 - else: - self.bits = 8 - - def quantized(self): - nn.quantize(self, self.group_size, self.bits, mode=self.mode) 
- # for layer in self.h.layers: - # nn.quantize(layer.feed_forward, self.group_size, self.bits) - # nn.quantize(layer.attention, self.group_size, self.bits) diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py index ca402210..e6724eb2 100644 --- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py +++ b/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py @@ -2,7 +2,7 @@ from __future__ import annotations import mlx.core as mx -from ..structs_mlx import KVCache, KVCacheQ +from ..structs_mlx import KVCache from ..t2s_model_abc import ( AttentionABC, KVCacheHND, @@ -19,23 +19,24 @@ class Attention(AttentionABC): super().__init__(n_head, hidden_dim, max_seq_length) self.kc_class = KVCacheHND - def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array): + def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array): bsz, seqlen, _ = x.shape - q, k, v = self.in_proj(x).split(3, axis=-1) + qkv = self.in_proj(x) - q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v)) + q, k, v = mx.split(qkv, 3, -1) - q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v)) + q = q.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3) + k = k.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3) + v = v.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3) kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx) - assert len(kv_cache) == 2 k, v = kv_cache attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask) - attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim) + attn = attn.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1) attn = self.out_proj(attn) @@ -85,7 +86,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config: dict, - max_seq_length: int = 2000, + max_seq_length: int = 1500, max_batch_size: int = 10, ) -> None: 
super().__init__(config, max_seq_length, max_batch_size) diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py index 92d51dc3..11c73589 100644 --- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py +++ b/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py @@ -2,7 +2,7 @@ from __future__ import annotations import mlx.core as mx -from ..structs_mlx import KVCache, KVCacheQ +from ..structs_mlx import KVCache from ..t2s_model_abc import ( AttentionABC, KVCacheHND, @@ -19,7 +19,7 @@ class Attention(AttentionABC): super().__init__(n_head, hidden_dim, max_seq_length) self.kc_class = KVCacheHND - def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array): + def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array): bsz, seqlen, _ = x.shape q, k, v = self.in_proj(x).split(3, axis=-1) @@ -29,7 +29,6 @@ class Attention(AttentionABC): q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v)) kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx) - assert len(kv_cache) == 2 max_idx = int(input_pos.max()) @@ -39,7 +38,7 @@ class Attention(AttentionABC): attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) - attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim) + attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, -1) attn = self.out_proj(attn) @@ -89,7 +88,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config: dict, - max_seq_length: int = 2000, + max_seq_length: int = 1500, max_batch_size: int = 10, ) -> None: super().__init__(config, max_seq_length, max_batch_size) diff --git a/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py b/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py index ee4ada55..dcc02a63 100644 --- a/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py @@ -1,3 +1,4 @@ +from functools import partial from typing 
import Protocol import mlx.core as mx @@ -17,8 +18,56 @@ class SampleProtocolMLX(Protocol): ) -> Array: ... +def apply_repetition_penalty(logits: Array, previous_tokens: Array, repetition_penalty: float): + batch_idx = mx.arange(previous_tokens.shape[0]) + selected_logits = logits[batch_idx, previous_tokens] + selected_logits = mx.where( + selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty + ) + logits[batch_idx, previous_tokens] = selected_logits + return logits + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_greedy_sampling(logits: Array): + return mx.argmax(logits, axis=-1, keepdims=True).astype(mx.int32) + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_temperature(logits: Array, temperature: float): + return logits / temperature + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_top_k(logits: Array, top_k: int): + v = mx.topk(logits, top_k) + pivot = mx.expand_dims(v[:, 0], -1) + logits = mx.where(logits < pivot, -mx.inf, logits) + return logits + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_top_p(logits: Array, top_p: float): + sorted_indices = mx.argsort(-logits, axis=-1) + sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1) + cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[:, -1] = False + indices_to_remove = mx.zeros_like(logits).astype(mx.bool_) + batch_indices = mx.arange(logits.shape[0])[:, None] + indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove + logits = mx.where(indices_to_remove, -mx.inf, logits) + return logits + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_sampling(logits: Array): + gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype) + idx_next = mx.argmax(logits + 
gumbel_noise, axis=-1, keepdims=True).astype(mx.int32) + return idx_next + + class sample_naive(SampleProtocolMLX): - # @partial(mx.compile) @staticmethod def __call__( logits, @@ -28,38 +77,18 @@ class sample_naive(SampleProtocolMLX): top_p, repetition_penalty, ): - if temperature <= 1e-5: - probs = mx.softmax(logits, axis=-1) - return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32) - if repetition_penalty != 1.0: - batch_idx = mx.arange(previous_tokens.shape[0]) - previous_tokens = previous_tokens.astype(mx.int64) - selected_logists = logits[batch_idx, previous_tokens] - selected_logists = mx.where( - selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty - ) - logits[batch_idx, previous_tokens] = selected_logists + logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty) + + if temperature <= 1e-5: + return apply_greedy_sampling(logits) + elif temperature < 1.0: + logits = apply_temperature(logits, temperature) + + if top_k < 1025: + logits = apply_top_k(logits, top_k) if top_p < 1.0: - sorted_indices = mx.argsort(-logits, axis=-1) - sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1) - cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1) - sorted_indices_to_remove = cum_probs > top_p - sorted_indices_to_remove[:, -1] = False - indices_to_remove = mx.zeros_like(logits).astype(mx.bool_) - batch_indices = mx.arange(logits.shape[0])[:, None] - indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove - logits = mx.where(indices_to_remove, -mx.inf, logits) + logits = apply_top_p(logits, top_p) - if temperature < 1.0: - logits = logits / temperature - - v = mx.topk(logits, top_k) - pivot = mx.expand_dims(v[:, 0], -1) - logits = mx.where(logits < pivot, -mx.inf, logits) - - gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype) - idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32) - - return 
idx_next + return apply_sampling(logits) diff --git a/GPT_SoVITS/Accelerate/MLX/structs_mlx.py b/GPT_SoVITS/Accelerate/MLX/structs_mlx.py index ce2fd003..426309e6 100644 --- a/GPT_SoVITS/Accelerate/MLX/structs_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/structs_mlx.py @@ -29,6 +29,7 @@ class T2SRequestMLX: early_stop_num: int = -1 temperature: float = 1.0 repetition_penalty: float = 1.35 + debug: bool = False @classmethod def from_torch(cls, request: T2SRequest) -> T2SRequestMLX: @@ -48,29 +49,27 @@ class T2SRequestMLX: request.early_stop_num, request.temperature, request.repetition_penalty, + request.debug, ) KVCache: TypeAlias = tuple[Array, Array] -KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]] class KVCacheProtocol(Protocol): @staticmethod - def empty(kv_cache: KVCache | KVCacheQ) -> None: ... + def empty(kv_cache: KVCache) -> None: ... @staticmethod - def update_cache( - input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array - ) -> KVCache | KVCacheQ: ... + def update_cache(input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache, cache_idx: Array) -> KVCache: ... @staticmethod - def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ... + def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache) -> None: ... @staticmethod def init_cache( batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds - ) -> KVCache | KVCacheQ: ... + ) -> KVCache: ... 
class T2SDecoderProtocol(Protocol): @@ -104,7 +103,7 @@ class T2SSessionMLX: self.y_len = y_len # Cache - self.kv_cache: MutableSequence[KVCache | KVCacheQ] + self.kv_cache: MutableSequence[KVCache] self.sample = sample_func() # Forward args @@ -118,6 +117,7 @@ class T2SSessionMLX: self.input_pos = mx.zeros_like(self.prefill_len) self.input_pos += self.prefill_len + self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement # EOS self.completed = mx.array([False] * len(self.x)).astype(mx.bool_) @@ -148,5 +148,3 @@ class T2SSessionMLX: attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1) self.attn_mask = attn_mask - - mx.eval(self.attn_mask) diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py index 13b4c088..d21557cb 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py @@ -1,22 +1,23 @@ -import gc import os import time import traceback -from typing import cast +from typing import Literal, cast import mlx.core as mx import torch from rich.progress import BarColumn, Progress, TextColumn -from ..logger import SpeedColumnToken, console, logger +from ..logger import SpeedColumnToken, Timer, console, logger from ..PyTorch.structs import T2SEngineProtocol, T2SRequest, T2SResult -from .backends import mlx_quantized, mlx_static, mlx_varlen +from .backends import mlx_static, mlx_varlen from .structs_mlx import T2SSessionMLX from .t2s_model_abc import T2SDecoderABC Array = mx.array Tensor = torch.Tensor +timer = Timer() + class T2SEngine(T2SEngineProtocol): def __init__( @@ -31,10 +32,10 @@ class T2SEngine(T2SEngineProtocol): device = mx.Device(mx.cpu) case "mx.gpu": device = mx.Device(mx.gpu) - + device = cast(mx.Device, device) match dtype: case torch.float32: - dtype = mx.float32 + dtype = mx.float16 if device.type == mx.gpu else mx.float32 case torch.float16: dtype = mx.float16 case torch.bfloat16: @@ -59,13 +60,14 @@ class 
T2SEngine(T2SEngineProtocol): decoder = self.decoder_model session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype) batch_idx = mx.arange(session.bsz) + debug = request.debug t1 = 0.0 infer_speed = 0.0 infer_time = 0.0 + idx = 0 with ( - mx.stream(session.device), Progress( TextColumn("[cyan]{task.description}"), BarColumn(), @@ -75,29 +77,47 @@ class T2SEngine(T2SEngineProtocol): transient=True, ) as progress, ): - max_token = min(2000 - int(session.input_pos.max()), 1500) + max_token = min(1500 - int(session.input_pos.max()), 1000) * session.bsz task = progress.add_task("T2S Decoding", total=max_token) - for idx in range(1500): - progress.update(task, advance=1) + for idx in range(max_token): + progress.update(task, advance=session.bsz) if idx == 0: session.kv_cache = decoder.init_cache(session.bsz) - xy_dec = decoder.h.prefill( - session.xy_pos, - session.attn_mask, - session.kv_cache, - ) # bs, seq_len, embed_dim - xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] + t1 = time.perf_counter() + with timer("MLX.Prefill", debug=debug): + xy_dec = decoder.h.prefill( + session.xy_pos, + session.attn_mask, + session.kv_cache, + ) # bs, seq_len, embed_dim + xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] + if debug: + mx.eval(xy_dec) else: args, kwds = decoder.pre_forward(session) - xy_dec = decoder.h( - session.input_pos, - session.xy_pos, - session.kv_cache, - batch_idx, - *args, - **kwds, - ) + + if debug: + mx.eval(session.input_pos, session.xy_pos, session.kv_cache, args, kwds, batch_idx) + + if debug and idx == 50 and os.environ.get("MTL_CAPTURE_ENABLED") == "1": + os.makedirs("./profiler/mlx", exist_ok=True) + mx.metal.start_capture(f"./profiler/mlx/{time.time()}.gputrace") + + with timer("MLX.Decode", debug=debug): + xy_dec = decoder.h( + session.input_pos, + session.xy_pos, + session.kv_cache, + batch_idx, + *args, + **kwds, + ) + if debug: + mx.eval(xy_dec) + + if debug and idx == 50 and 
os.environ.get("MTL_CAPTURE_ENABLED") == "1": + mx.metal.stop_capture() decoder.post_forward(idx, session) logits = decoder.ar_predict_layer(xy_dec[:, -1]) @@ -106,28 +126,38 @@ class T2SEngine(T2SEngineProtocol): if idx == 0: logits[:, -1] = -mx.inf - samples = session.sample( - logits=logits, - previous_tokens=session.y[:, : session.y_len + idx], - top_k=request.top_k, - top_p=request.top_p, - repetition_penalty=request.repetition_penalty, - temperature=request.temperature, - ) + with timer("MLX.Sampling", debug=debug): + samples = session.sample( + logits=logits, + previous_tokens=session.y[:, : session.y_len + idx], + top_k=request.top_k, + top_p=request.top_p, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + ) - session.y[batch_idx, session.y_len + idx] = samples + session.y[batch_idx, session.y_len + idx] = samples - argmax_token = mx.argmax(logits, axis=-1) - sample_token = samples.squeeze(1) - EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS) + if debug: + mx.eval(samples) - newly_done_mask = EOS_mask & (~session.completed) - newly_done_indices = mx.where(newly_done_mask, batch_idx, -1) - pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz) - pos_sorted = mx.sort(pos, axis=0) - valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz)) - pos_final = pos_sorted[: int(valid_count)] - newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0) + with timer("MLX.EOS", debug=debug): + mx.set_default_device(mx.Device(mx.cpu)) + argmax_token = mx.argmax(logits, axis=-1) + sample_token = samples.squeeze(1) + EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS) + + newly_done_mask = EOS_mask & (~session.completed) + newly_done_indices = mx.where(newly_done_mask, batch_idx, -1) + pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz) + pos_sorted = mx.sort(pos, axis=0) + valid_count = session.bsz - 
mx.sum(cast(Array, pos_sorted == session.bsz)) + pos_final = pos_sorted[: int(valid_count)] + newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0) + mx.set_default_device(self.device) + + if debug: + mx.eval(newly_done_indices) if newly_done_indices.size > 0: for i in newly_done_indices: @@ -135,54 +165,53 @@ class T2SEngine(T2SEngineProtocol): session.completed[newly_done_indices] = True if mx.all(session.completed).item(): - if session.y[:, session.y_len :].sum() == 0: - session.y_results = [mx.array([0]) for _ in range(session.bsz)] - logger.error("Bad Zero Prediction") - else: - logger.info( - f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}" - ) - logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s") - infer_time = time.perf_counter() - t1 - infer_speed = (idx - 1) / infer_time + logger.info( + f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}" + ) + logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s") + infer_time = time.perf_counter() - t1 + infer_speed = (idx + 1) / infer_time break if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1: for j in range(session.bsz): if not session.completed[j].item(): - session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499] + session.y_results[j] = session.y[[j], session.y_len : session.y_len + idx] session.completed[j] = True logger.error("Bad Full Prediction") - logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s") + logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s") infer_time = time.perf_counter() - t1 - infer_speed = (idx - 1) / infer_time + infer_speed = (idx + 1) / infer_time break - y_emb = decoder.ar_audio_embedding(samples) - session.xy_pos = 
decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb) - mx.eval(session.xy_pos, session.y) + with timer("MLX.NextPos", debug=debug): + y_emb = decoder.ar_audio_embedding(samples) + session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb) + mx.eval(session.xy_pos, session.y) - if idx == 1: - t1 = time.perf_counter() - - if idx % 100 == 0: + if idx % 128 == 0: mx.clear_cache() - match session.device: - case mx.gpu: - mx.clear_cache() - case mx.cpu: - gc.collect() - result_mlx = session.y_results[: request.valid_length] mx.eval(result_mlx) result = [torch.tensor(k) for k in result_mlx] - return result, infer_speed, infer_time + mx.clear_cache() + + if debug: + timer.summary() + timer.clear() + + return result, infer_speed, infer_time, (idx + 1) * session.bsz def generate(self, request: T2SRequest): try: - result, infer_speed, infer_time = self._handle_request(request) - t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success") + result, infer_speed, infer_time, total_tokens = self._handle_request(request) + t2s_result = T2SResult( + result=result, + infer_speed=(infer_speed, infer_time), + total_tokens=total_tokens, + status="Success", + ) except Exception as e: t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc()) return t2s_result @@ -199,40 +228,39 @@ class T2SEngine(T2SEngineProtocol): .replace("norm1", "attention_norm") .replace("norm2", "ffn_norm") ) - value_mlx = mx.array(value) # type: ignore + value_mlx = mx.array(value.to(torch.float32).cpu().numpy()) state_dict_mlx.append((key, value_mlx)) return state_dict_mlx @staticmethod - def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"): + def load_decoder( + weights_path: os.PathLike, + max_batch_size: int = 1, + backend: str = "MLX-Varlen", + quantize_mode: Literal["Affine", "MXFP4"] | None = None, + ) -> T2SDecoderABC: logger.info(f"Loading Text2Semantic 
Weights from {weights_path} with {backend} Backend") - dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True) + dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=True, mmap=True) config = dict_s1["config"] match backend: case "MLX-Varlen": decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder case "MLX-Static": decoder_cls = mlx_static.T2SDecoder - case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4": - decoder_cls = mlx_quantized.T2SDecoder case _: raise RuntimeError(f"Backend {backend} Not Found") decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size) state_dict = dict_s1["weight"] state_dict_mlx = T2SEngine.replace_key(state_dict) + decoder.load_weights(state_dict_mlx) - decoder.eval() + + if quantize_mode is not None: + decoder.quantize(quantize_mode) + logger.info( + f"Quantized to {decoder.bits}-Bit with Group Size {decoder.group_size} by {quantize_mode} Quantization" + ) + mx.eval(decoder) - - if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder): - if backend == "MLX-Quantized-Affine": - decoder.set_mode("affine") - elif backend == "MLX-Quantized-MXFP4": - decoder.set_mode("mxfp4") - else: - raise RuntimeError(f"Quantized Backend {backend} Not Supported") - decoder.quantized() - mx.eval(decoder) - return decoder diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py b/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py index 0bff4219..80de8f39 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py @@ -2,12 +2,13 @@ from __future__ import annotations import math from abc import ABC, abstractmethod -from typing import MutableSequence +from typing import Literal, MutableSequence, Type import mlx.core as mx import mlx.nn as nn +from mlx.core import Dtype -from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX +from .structs_mlx import KVCache, KVCacheProtocol, T2SDecoderProtocol, 
T2SSessionMLX Array = mx.array @@ -43,26 +44,28 @@ class SinePositionalEmbedding(nn.Module): embedding_dim: int, scale: bool = False, max_batch_size: int = 10, - max_seq_len: int = 2000, + max_seq_length: int = 1500, ): super().__init__() self.embedding_dim = embedding_dim self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 self.alpha = mx.ones(1) self.max_batch_size = max_batch_size - self.max_seq_len = max_seq_len + self.max_seq_length = max_seq_length self.reverse = False - self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim)) - self.compute_pe() + self.pe: Array | None = None - def compute_pe(self): - """Reset the positional encodings.""" + def compute_pe(self, dtype: Dtype): + """Compute the positional encodings.""" + + if self.pe is not None and self.pe.dtype == dtype: + return if self.reverse: - position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1) + position = mx.expand_dims(mx.arange(self.max_seq_length - 1, -1, -1.0), axis=1) else: - position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1) + position = mx.expand_dims(mx.arange(self.max_seq_length), axis=1) div_term = mx.exp( mx.arange( 0, @@ -70,10 +73,13 @@ class SinePositionalEmbedding(nn.Module): 2, ) * -(math.log(10000.0) / self.embedding_dim) - ) - pe = self._pe - pe[:, :, 0::2] = mx.sin(position * div_term) - pe[:, :, 1::2] = mx.cos(position * div_term) + ).astype(dtype) + pe = mx.zeros((self.max_batch_size, self.max_seq_length, self.embedding_dim)).astype(dtype) + + pe[:, :, 0::2] = mx.sin(position * div_term).astype(dtype) + pe[:, :, 1::2] = mx.cos(position * div_term).astype(dtype) + + self.pe = pe def __call__(self, input_pos: Array, x: Array): """ @@ -84,9 +90,11 @@ class SinePositionalEmbedding(nn.Module): Returns: embedded_x (Array): [batch_size, 1, embed_dim] """ + self.compute_pe(x.dtype) + assert self.pe is not None batch_size = x.shape[0] - pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim) + pe_values = 
self.pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim) return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim) @@ -98,7 +106,10 @@ class SinePositionalEmbedding(nn.Module): Returns: embedded_x (Array): [batch_size, seq_len, embed_dim] """ - pe_values = self._pe[:, : x.shape[-2]] + self.compute_pe(x.dtype) + assert self.pe is not None + + pe_values = self.pe[:, : x.shape[-2]] return x * self.x_scale + self.alpha * pe_values @@ -125,12 +136,12 @@ class KVCacheHND(KVCacheProtocol): @staticmethod def prefill_kv(k_val, v_val, kv_cache): - # k_val: [B, S, H, D] + # k_val: [B, H, S, D] assert len(kv_cache) == 2 k_cache, v_cache = kv_cache - k_cache[..., : k_val.shape[1], :] = k_val.swapaxes(1, 2) - v_cache[..., : v_val.shape[1], :] = v_val.swapaxes(1, 2) + k_cache[..., : k_val.shape[2], :] = k_val + v_cache[..., : v_val.shape[2], :] = v_val @staticmethod def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache: @@ -139,118 +150,6 @@ class KVCacheHND(KVCacheProtocol): return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype)) -class KVCacheHNDQuantized(KVCacheProtocol): - @staticmethod - def _el_per_int(bits: int) -> int: - return 32 // bits - - @staticmethod - def _packed_dim(head_dim: int, bits: int = 8) -> int: - el_per_int = KVCacheHNDQuantized._el_per_int(bits) - if head_dim % el_per_int != 0: - raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})") - return head_dim // el_per_int - - @staticmethod - def _group_count(head_dim: int, group_size: int = 32) -> int: - assert group_size in {32, 64, 128} - if head_dim % group_size != 0: - raise ValueError(f"{head_dim} is not divisible by {group_size=}") - return head_dim // group_size - - @staticmethod - def empty(kv_cache) -> None: - assert len(kv_cache) == 3 - (k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache - - k_q[:] = 0 - k_s[:] = 0 - k_b[:] = 0 - v_q[:] = 0 - 
v_s[:] = 0 - v_b[:] = 0 - - @staticmethod - def update_cache( - input_pos, - k_val, - v_val, - kv_cache, - cache_idx, - ): - # input_pos: [B, ], k_val: [B, H, 1, D] - - assert len(kv_cache) == 3 - (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache - - k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits) - v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits) - - ip0 = input_pos - 1 - - k_q_out[cache_idx, :, ip0, None] = k_q - k_s_out[cache_idx, :, ip0, None] = k_s - k_b_out[cache_idx, :, ip0, None] = k_b - - v_q_out[cache_idx, :, ip0, None] = v_q - v_s_out[cache_idx, :, ip0, None] = v_s - v_b_out[cache_idx, :, ip0, None] = v_b - - return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) - - @staticmethod - def prefill_kv( - k_val, - v_val, - kv_cache, - ) -> None: - assert len(kv_cache) == 3 - (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache - - S = k_val.shape[1] - - k_sw = k_val.swapaxes(1, 2) - v_sw = v_val.swapaxes(1, 2) - - k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits) - v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits) - - k_q_out[..., :S, :] = k_q - k_s_out[..., :S, :] = k_s - k_b_out[..., :S, :] = k_b - - v_q_out[..., :S, :] = v_q - v_s_out[..., :S, :] = v_s - v_b_out[..., :S, :] = v_b - - @staticmethod - def init_cache( - batch_size: int, - max_seq_length: int, - n_heads: int, - head_dim: int, - dtype: mx.Dtype, - *, - group_size: int = 32, - bits: int = 8, - ) -> KVCacheQ: - packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits) - group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size) - - packed_shape = (batch_size, n_heads, max_seq_length, packed_dim) - group_shape = (batch_size, n_heads, max_seq_length, group_cnt) - - k_q = mx.zeros(packed_shape, dtype=mx.uint32) - k_s = mx.zeros(group_shape, dtype=dtype) - k_b = mx.zeros(group_shape, dtype=dtype) - 
- v_q = mx.zeros(packed_shape, dtype=mx.uint32) - v_s = mx.zeros(group_shape, dtype=dtype) - v_b = mx.zeros(group_shape, dtype=dtype) - - return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) - - class AttentionABC(ABC, nn.Module): def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds): super().__init__() @@ -271,26 +170,26 @@ class AttentionABC(ABC, nn.Module): self.kc_class: KVCacheProtocol @abstractmethod - def __call__( - self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array - ) -> Array: ... + def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array) -> Array: ... - def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array): + def prefill(self, x: Array, kv_cache: KVCache, attn_mask: Array): bsz, seqlen, _ = x.shape - q, k, v = self.in_proj(x).split(3, axis=-1) + qkv = self.in_proj(x) - q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v)) + q, k, v = mx.split(qkv, 3, -1) + + q = q.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3) + k = k.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3) + v = v.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3) self.kc_class.prefill_kv(k, v, kv_cache) - q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v)) - attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale) attn = mx.nan_to_num(attn) - attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim) + attn = attn.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1) output = self.out_proj(attn) @@ -321,7 +220,7 @@ class TransformerBlockABC(nn.Module): self.attention_norm = nn.LayerNorm(self.hidden_dim) self.ffn_norm = nn.LayerNorm(self.hidden_dim) - def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array): + def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, 
attn_mask: Array): h = self.attention_norm( x + self.attention( @@ -335,7 +234,7 @@ class TransformerBlockABC(nn.Module): out = self.ffn_norm(h + self.feed_forward(h)) return out - def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ): + def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache): h = self.attention_norm( x + self.attention.prefill( @@ -382,7 +281,7 @@ class TransformerDecoderABC(nn.Module): self, input_pos: Array, x: Array, - kv_caches: MutableSequence[KVCache | KVCacheQ], + kv_caches: MutableSequence[KVCache], cache_idx: Array, *args, **kwds, @@ -399,7 +298,7 @@ class TransformerDecoderABC(nn.Module): return x - def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]): + def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache]): for layer, kv_cache in zip(self.layers, kv_caches): x = layer.prefill( x, @@ -413,7 +312,7 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol): def __init__( self, config: dict, - max_seq_length: int = 2000, + max_seq_length: int = 1500, max_batch_size: int = 10, ) -> None: super().__init__() @@ -451,24 +350,27 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol): self.embedding_dim, scale=False, max_batch_size=max_batch_size, - max_seq_len=max_seq_length, + max_seq_length=max_seq_length, ) self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size) self.ar_audio_position = SinePositionalEmbedding( self.embedding_dim, scale=False, max_batch_size=max_batch_size, - max_seq_len=max_seq_length, + max_seq_length=max_seq_length, ) - self.kv_class: KVCacheProtocol + self.kv_class: Type[KVCacheProtocol] - def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]: + self.bits: int = -1 + self.group_size: int = -1 + + def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache]: bsz = bsz or self.h.max_batch_size assert bsz <= self.h.max_batch_size seq_lens = self.h.max_seq_length dtype 
= self.bert_proj.bias.dtype - cache: MutableSequence[KVCache | KVCacheQ] = [ + cache: MutableSequence[KVCache] = [ self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds) for _ in range(self.n_layer) ] @@ -503,8 +405,7 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol): return xy_pos def compile(self): - setattr(self.h, "__call__", mx.compile(self.h.__call__)) - # setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True)) + setattr(self.h, "__call__", mx.compile(self.h.__call__, shapeless=True)) def pre_forward(self, session: T2SSessionMLX): attn_mask = session.attn_mask @@ -525,4 +426,21 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol): attn_mask = session.attn_mask input_pos = session.input_pos attn_mask[mx.arange(session.bsz), :, :, input_pos] = True - mx.eval(attn_mask) + + def quantize(self, mode: Literal["Affine", "MXFP4"] | None = None) -> None: + if mode is None: + return + if mode not in {"Affine", "MXFP4"}: + raise ValueError(f"Unsupported quantization mode: {mode}") + match mode: + case "Affine": + self.bits = 8 + self.group_size = 32 + nn.quantize(self.h, group_size=self.group_size, bits=self.bits, mode="affine") + case "MXFP4": + self.bits = 4 + self.group_size = 32 + nn.quantize(self.h, group_size=self.group_size, bits=self.bits, mode="mxfp4") + + case _: + raise ValueError(f"Unsupported Quantization Mode for MLX: {mode}") diff --git a/GPT_SoVITS/Accelerate/PyTorch/__init__.py b/GPT_SoVITS/Accelerate/PyTorch/__init__.py index 91617fc6..7cef76e7 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/__init__.py +++ b/GPT_SoVITS/Accelerate/PyTorch/__init__.py @@ -7,10 +7,23 @@ from .structs import T2SRequest, T2SResult from .t2s_engine import T2SEngine as T2SEngineTorch torch.set_grad_enabled(False) + +if torch.__version__ >= "2.9.0": + torch.backends.fp32_precision = "tf32" # type: ignore +else: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +if 
torch.cuda.is_available(): + torch.backends.cuda.preferred_blas_library("cublaslt") + + +torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) +torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True +torch.backends.cuda.matmul.allow_fp16_accumulation = True +torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = True -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True backends = ["torch_varlen"] if torch.cuda.is_available(): @@ -30,5 +43,21 @@ if torch.cuda.is_available(): # if torch.mps.is_available(): # backends.append("mps_flash_attn_varlen") +BLACKWELL = False +if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + sm_version = major + minor / 10.0 + if sm_version >= 9.0: + BLACKWELL = True -__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"] +quantization_methods_torch: list[str | None] = [None] +if importlib.util.find_spec("torchao") is not None: + quantization_methods_torch.append("Int8") + if BLACKWELL: + quantization_methods_torch.append("FP8") +if BLACKWELL: + quantization_methods_torch.append("FP8_E4M3FN") + + +__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends", "quantization_methods_torch"] diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py index 6b699198..adf7fceb 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py @@ -100,7 +100,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config, - max_seq_length=2000, + max_seq_length=1500, max_batch_size=10, ) -> None: assert torch.cuda.is_available() diff --git 
a/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py b/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py index f8b5d0a1..981a71df 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py @@ -78,7 +78,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config, - max_seq_length=2000, + max_seq_length=1500, max_batch_size=10, ) -> None: super().__init__(config, max_seq_length, max_batch_size) diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py index ddd150c3..cd98739e 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py @@ -94,7 +94,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config, - max_seq_length=2000, + max_seq_length=1500, max_batch_size=10, ) -> None: super().__init__(config, max_seq_length, max_batch_size) diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py index 961bc19a..37cae95d 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py @@ -78,7 +78,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config, - max_seq_length=2000, + max_seq_length=1500, max_batch_size=10, ) -> None: super().__init__(config, max_seq_length, max_batch_size) diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py index 3618376e..cd16c79d 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py @@ -86,7 +86,7 @@ class T2SDecoder(T2SDecoderABC): def __init__( self, config, - max_seq_length=2000, + 
max_seq_length=1500, max_batch_size=10, ) -> None: super().__init__(config, max_seq_length, max_batch_size) diff --git a/GPT_SoVITS/Accelerate/PyTorch/quantization.py b/GPT_SoVITS/Accelerate/PyTorch/quantization.py new file mode 100644 index 00000000..aff25cb5 --- /dev/null +++ b/GPT_SoVITS/Accelerate/PyTorch/quantization.py @@ -0,0 +1,158 @@ +from typing import cast + +import torch + +from . import nn + +Tensor = torch.Tensor + + +# based on ComfyUI's and MinusZoneAI's fp8_linear optimization +def fp8_linear_forward(cls: nn.Linear, input: Tensor): + weight_dtype = cls.weight.dtype + base_dtype = input.dtype + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if len(input.shape) == 3: + input_shape = input.shape + + scale_weight: Tensor | None = getattr(cls, "scale_weight", None) + if scale_weight is None: + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + else: + scale_weight = scale_weight.to(input.device).squeeze() + + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + + input = torch.clamp(input, min=-448, max=448, out=input) + inn = ( + input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() + ) # always e4m3fn because e5m2 * e5m2 is not supported + + bias = cls.bias if cls.bias is not None else None + + o = torch._scaled_mm( + inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight + ) + + return o.reshape((-1, input_shape[1], cls.weight.shape[0])) + else: + raise + else: + raise + + +def convert_fp8_linear( + module: nn.Module, +): + apply_fn = fp8_linear_forward + + for _, sub in list(module.named_modules()): + if isinstance(sub, nn.Linear): + if getattr(sub, "_fp8", False): + continue + setattr(sub, "forward", apply_fn) + setattr(sub, "_fp8", True) + return module + + +def per_tensor_quantize(tensor: torch.Tensor) -> tuple[Tensor, Tensor]: + """Quantize a tensor using per-tensor static scaling factor. + Args: + tensor: The input tensor. 
+ """ + finfo = torch.finfo(torch.float8_e4m3fn) + # Calculate the scale as dtype max divided by absmax. + # Since .abs() creates a new tensor, we use aminmax to get + # the min and max first and then calculate the absmax. + if tensor.numel() == 0: + # Deal with empty tensors (triggered by empty MoE experts) + min_val, max_val = ( + torch.tensor(0.0, dtype=tensor.dtype), + torch.tensor(1.0, dtype=tensor.dtype), + ) + else: + min_val, max_val = tensor.aminmax() + amax = min_val.abs().max(max_val.abs()) + scale = finfo.max / amax.clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(torch.float8_e4m3fn) + scale = scale.float().reciprocal() + return qweight, scale + + +def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): + cuda_compute_capability = torch.cuda.get_device_capability() + if cuda_compute_capability >= (9, 0): + output, _ = torch._scaled_mm( + A, + B.t(), + out_dtype=out_dtype, + scale_a=A_scale, + scale_b=B_scale, + bias=bias, + ) + else: + output = torch.nn.functional.linear( + A.to(out_dtype) * A_scale, + B.to(out_dtype) * B_scale.to(out_dtype), + bias=bias, + ) + return output + + +class FP8DynamicLinear(nn.Module): + def __init__(self, qweight: Tensor, scale: Tensor, bias: Tensor): + super().__init__() + self.weight = torch.nn.Parameter(qweight, requires_grad=False) + self.weight_scale = torch.nn.Parameter(scale, requires_grad=False) + self.bias = bias + + def __call__(self, x): + qinput, x_scale = per_tensor_quantize(x) + output = fp8_gemm( + A=qinput, + A_scale=x_scale, + B=self.weight, + B_scale=self.weight_scale, + bias=self.bias, + out_dtype=x.dtype, + ) + return output + + +def replace_all_linear_with_fp8(model: nn.Module): + """ + Recursively replace 
every nn.Linear with FP8DynamicLinear in-place. + """ + + def _recursively_replace(parent: nn.Module): + for child_name, child in list(parent.named_children()): + child = cast(nn.Module, child) + if isinstance(child, FP8DynamicLinear): + continue + + if isinstance(child, nn.Linear): + device = child.weight.device + + w = child.weight + + b = child.bias.clone() + + qw, qs = per_tensor_quantize(w) + + quant_linear = FP8DynamicLinear(qw, qs, b) + + quant_linear.to(device) + + setattr(parent, child_name, quant_linear) + + del child + else: + _recursively_replace(child) + + _recursively_replace(model) diff --git a/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py b/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py index 0b9eec0c..06bdfa7c 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py +++ b/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py @@ -1,20 +1,79 @@ -from typing import Protocol +from typing import Callable, Protocol, TypeVar, cast import torch import torch.nn.functional as F +from typing_extensions import ParamSpec +P = ParamSpec("P") +R = TypeVar("R") Tensor = torch.Tensor +def script(fn: Callable[P, R]) -> Callable[P, R]: + scripted = torch.jit.script(fn) + return cast(Callable[P, R], scripted) + + +@script +def apply_repetition_penalty(logits: Tensor, previous_tokens: Tensor, repetition_penalty: float): + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=1, index=previous_tokens, src=score) + return logits + + +@script +def apply_greedy_sampling(logits: Tensor): + return torch.argmax(logits, dim=-1, keepdim=True).to(dtype=torch.int32) + + +@script +def apply_temperature(logits: Tensor, temperature: float): + return logits / temperature + + +@script +def apply_top_k(logits: Tensor, top_k: int): + v, _ = torch.topk(logits, top_k) + pivot = v[:, -1].unsqueeze(-1) + logits = 
torch.where(logits < pivot, -float("Inf"), logits) + return logits + + +@script +def apply_top_p(logits: Tensor, top_p: float): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + cum_probs[cum_probs > 1] = 1 + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[:, 0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + return logits + + +@script +def apply_sampling(logits: Tensor): + probs = F.softmax(logits, dim=-1) + q = -torch.log(torch.rand_like(probs)) + idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32) + return idx_next + + class SampleProtocol(Protocol): @staticmethod def __call__( logits: Tensor, previous_tokens: Tensor, + repetition_penalty: float, temperature: float, top_k: int, top_p: float, - repetition_penalty: float, ) -> Tensor: ... 
@@ -23,45 +82,23 @@ class sample_naive(SampleProtocol): def __call__( logits: Tensor, previous_tokens: Tensor, - temperature: float, - top_k: int, - top_p: float, - repetition_penalty: float, + repetition_penalty: float = 1.35, + temperature: float = 1.0, + top_k: int = 15, + top_p: float = 1.0, ): - if temperature <= 1e-5: - probs = F.softmax(logits, dim=-1) - return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32) - if repetition_penalty != 1.0: - previous_tokens = previous_tokens.long() - score = torch.gather(logits, dim=1, index=previous_tokens) - score = torch.where( - score < 0, - score * repetition_penalty, - score / repetition_penalty, - ) - logits.scatter_(dim=1, index=previous_tokens, src=score) + logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty) + + if temperature <= 1e-5: + return apply_greedy_sampling(logits) + elif temperature < 1.0: + logits = apply_temperature(logits, temperature) + + if top_k < 1025: + logits = apply_top_k(logits, top_k) if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - cum_probs[cum_probs > 1] = 1 - sorted_indices_to_remove = cum_probs > top_p - sorted_indices_to_remove[:, 0] = False # keep at least one option - indices_to_remove = sorted_indices_to_remove.scatter( - dim=1, index=sorted_indices, src=sorted_indices_to_remove - ) - logits = logits.masked_fill(indices_to_remove, -float("Inf")) + logits = apply_top_p(logits, top_p) - if temperature < 1.0: - logits /= temperature - - v, _ = torch.topk(logits, top_k) - pivot = v[:, -1].unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - - probs = F.softmax(logits, dim=-1) - q = -torch.log(torch.rand_like(probs)) - idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32) - - return idx_next + return apply_sampling(logits) diff --git a/GPT_SoVITS/Accelerate/PyTorch/structs.py 
b/GPT_SoVITS/Accelerate/PyTorch/structs.py index 1822acdc..fe11a501 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/structs.py +++ b/GPT_SoVITS/Accelerate/PyTorch/structs.py @@ -18,6 +18,7 @@ Tensor = torch.Tensor class T2SResult: result: list[Tensor] | None = None infer_speed: tuple[float, float] = (0.0, 0.0) + total_tokens: int = 0 status: Literal["Success", "Error"] = "Success" exception: Optional[Exception] = None traceback: Optional[str] = None @@ -66,7 +67,7 @@ class T2SDecoderProtocol(Protocol): class T2SEngineProtocol(Protocol): - def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ... + def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float, int]: ... def generate(self, request: T2SRequest) -> T2SResult: ... @@ -107,6 +108,7 @@ class T2SSession: self.input_pos = torch.zeros_like(self.prefill_len) self.input_pos.add_(self.prefill_len) + self.input_pos.squeeze_(0) # CUDA Graph self.stream: Optional[torch.cuda.Stream] = None diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index ada5b096..89f18bc8 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -1,15 +1,15 @@ import contextlib -import gc import os import sys import time import traceback from importlib import import_module +from typing import Literal import torch from rich.progress import BarColumn, Progress, TextColumn -from ..logger import SpeedColumnToken, console, logger +from ..logger import SpeedColumnToken, console, logger, timer from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession from .t2s_model_abc import ( CUDAGraphCacheABC, @@ -41,12 +41,14 @@ class T2SEngine(T2SEngineProtocol): decoder = self.decoder_model session = T2SSession(decoder, request, device=self.device, dtype=self.dtype) batch_idx = torch.arange(session.bsz) + debug = request.debug t1 = 0.0 infer_speed = 0.0 infer_time = 0.0 + idx = 0 - 
torch_profiler = TorchProfiler(request.debug) + torch_profiler = TorchProfiler(debug) with ( torch_profiler.profiler(), Progress( @@ -59,14 +61,15 @@ class T2SEngine(T2SEngineProtocol): ) as progress, ): torch_profiler.start() - max_token = int(min(2000 - session.input_pos.max(), 1500)) + max_token = min(int(1500 - session.input_pos.max()), 1000) task = progress.add_task("T2S Decoding", total=max_token) for idx in range(max_token): progress.update(task, advance=1) if idx == 0: - with torch_profiler.record("Prefill"): + with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug): session.kv_cache = decoder.init_cache(session.bsz) + t1 = time.perf_counter() xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask) xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] else: @@ -78,7 +81,7 @@ class T2SEngine(T2SEngineProtocol): ): self.graphcache.assign_graph(session) - with torch_profiler.record("Decode"): + with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug): if session.graph: assert session.stream session.stream.wait_stream(torch.cuda.default_stream()) @@ -103,7 +106,7 @@ class T2SEngine(T2SEngineProtocol): if idx == 0: logits[:, -1] = float("-inf") - with torch_profiler.record("Sampling"): + with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug): samples = session.sample( logits=logits, previous_tokens=session.y[:, : session.y_len + idx], @@ -115,7 +118,7 @@ class T2SEngine(T2SEngineProtocol): session.y[batch_idx, session.y_len + idx] = samples session.input_pos.add_(1) - with torch_profiler.record("EOS"): + with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug): argmax_token = torch.argmax(logits, dim=-1) sample_token = samples.squeeze(1) EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS) @@ -128,73 +131,48 @@ class T2SEngine(T2SEngineProtocol): session.y_results[i] = session.y[i, session.y_len : session.y_len + idx] 
session.completed[newly_done_indices] = True - if torch.all(session.completed).item(): - if session.y[:, session.y_len :].sum() == 0: - session.y_results = [torch.tensor(0) for _ in range(session.bsz)] - logger.error("Bad Zero Prediction") - else: - logger.info( - f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}" - ) - logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s") - infer_time = time.perf_counter() - t1 - infer_speed = (idx - 1) / infer_time - break + if torch.all(session.completed).item(): + logger.info( + f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}" + ) + logger.info( + f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s" + ) + infer_time = time.perf_counter() - t1 + infer_speed = (idx + 1) * session.bsz / infer_time + break - if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1: - for i in range(session.bsz): - if not session.completed[i].item(): - session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499] - session.completed[i] = True - logger.error("Bad Full Prediction") - break + if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1: + for i in range(session.bsz): + if not session.completed[i].item(): + session.y_results[i] = session.y[[i], session.y_len : session.y_len + idx] + session.completed[i] = True + logger.error("Bad Full Prediction") + infer_time = time.perf_counter() - t1 + infer_speed = (idx + 1) * session.bsz / infer_time + break - with torch_profiler.record("NextPos"): + with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug): y_emb = decoder.ar_audio_embedding(samples) session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb) - if idx == 1: - t1 = 
time.perf_counter() - - if idx == 20: + if idx == 10: torch_profiler.end() - if idx % 100 == 0: - match session.device.type: - case "cuda": - torch.cuda.empty_cache() - case "mps": - torch.mps.empty_cache() - case "xpu": - torch.xpu.empty_cache() - case "mtia": - torch.mtia.empty_cache() - case "cpu": - pass - - match session.device.type: - case "cuda": - if session.stream is not None: - torch.cuda.current_stream().wait_stream(session.stream) - torch.cuda.empty_cache() - case "mps": - torch.mps.empty_cache() - case "xpu": - torch.xpu.empty_cache() - case "mtia": - torch.mtia.empty_cache() - case "cpu": - gc.collect(1) - if request.use_cuda_graph and self.graphcache.is_applicable: self.graphcache.release_graph(session) - return session.y_results[: request.valid_length], infer_speed, infer_time + return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz def generate(self, request: T2SRequest): try: - result, infer_speed, infer_time = self._handle_request(request) - t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success") + result, infer_speed, infer_time, total_tokens = self._handle_request(request) + t2s_result = T2SResult( + result=result, + infer_speed=(infer_speed, infer_time), + total_tokens=total_tokens, + status="Success", + ) except Exception as e: t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc()) if self.decoder_model.compiled: @@ -203,7 +181,12 @@ class T2SEngine(T2SEngineProtocol): return t2s_result @staticmethod - def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"): + def load_decoder( + weights_path: os.PathLike, + max_batch_size: int = 1, + backend: str = "Flash-Attn-Varlen-CUDAGraph", + quantize_mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None, + ) -> T2SDecoderABC: logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend") module_path = 
f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}" decoder_cls_name = "T2SDecoder" @@ -215,6 +198,10 @@ class T2SEngine(T2SEngineProtocol): state_dict = dict_s1["weight"] decoder.load_state_dict(state_dict) + if quantize_mode is not None: + decoder.quantize(quantize_mode) + logger.info(f"Quantized by {quantize_mode} Quantization") + return decoder.eval() def init_cache(self): diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py index 7b46b6ab..6c8b51f5 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py @@ -13,7 +13,7 @@ import time from abc import ABC, abstractmethod from contextlib import nullcontext from pathlib import Path -from typing import MutableSequence +from typing import Literal, MutableSequence import torch import torch._inductor.config @@ -24,6 +24,7 @@ from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_t from tools.my_utils import get_machine_id from . 
import nn +from .quantization import replace_all_linear_with_fp8 from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession Tensor = torch.Tensor @@ -61,26 +62,26 @@ class SinePositionalEmbedding(nn.Module): scale: bool = False, alpha: bool = False, max_batch_size: int = 10, - max_seq_len: int = 2000, + max_seq_length: int = 1500, ): super().__init__() self.embedding_dim = embedding_dim self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) self.max_batch_size = max_batch_size - self.max_seq_len = max_seq_len + self.max_seq_length = max_seq_length self.reverse = False - self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False) + self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_length, embedding_dim), persistent=False) self.pe: torch.Tensor self.compute_pe() def compute_pe(self): """Reset the positional encodings.""" if self.reverse: - position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + position = torch.arange(self.max_seq_length - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) else: - position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1) + position = torch.arange(self.max_seq_length, dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim) ) @@ -423,7 +424,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol): def __init__( self, config: dict, - max_seq_length: int = 2000, + max_seq_length: int = 1500, max_batch_size: int = 10, ) -> None: super().__init__() @@ -467,7 +468,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol): scale=False, alpha=True, max_batch_size=max_batch_size, - max_seq_len=max_seq_length, + max_seq_length=max_seq_length, ) self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size) self.ar_audio_position = 
SinePositionalEmbedding( @@ -475,9 +476,12 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol): scale=False, alpha=True, max_batch_size=max_batch_size, - max_seq_len=max_seq_length, + max_seq_length=max_seq_length, ) + self.bits: int + self.group_size: int + self._register_load_state_dict_pre_hook(self.load_hook) def load_hook(self, state_dict: dict[str, Tensor], prefix, *args): @@ -608,6 +612,32 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol): def post_forward(self, idx: int, session: T2SSession) -> None: return + def quantize(self, mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None) -> None: + if mode is None: + return + if mode not in {"Int8", "FP8", "FP8_E4M3FN"}: + raise ValueError(f"Unsupported quantization mode: {mode}") + match mode: + case "Int8": + self.bits = 8 + self.group_size = 32 + import torchao + + torchao.quantization.quantize_(self.h, torchao.quantization.Int8WeightOnlyConfig(self.group_size)) + + case "FP8": + self.bits = 8 + import torchao + + torchao.quantization.quantize_(self.h, torchao.quantization.Float8WeightOnlyConfig()) + + case "FP8_E4M3FN": + self.bits = 8 + replace_all_linear_with_fp8(self.h) + + case _: + raise ValueError(f"Unsupported Quantization Mode for PyTorch: {mode}") + class CUDAGraphCacheABC(ABC): def __init__( @@ -662,13 +692,13 @@ class CUDAGraphCacheABC(ABC): class TorchProfiler: - def __init__(self, debug: bool, log_dir: str = "./profiler") -> None: - self.debug = debug - self.log_dir = log_dir + str(time.time()) + def __init__(self, debug: bool, log_dir: str = "./profiler/torch") -> None: + self.debug = debug and os.environ.get("TORCH_PROFILER") == "1" + self.log_dir = log_dir + "/" + str(time.time()) self.__profiler: torch.profiler.profile if self.debug and not os.path.exists(self.log_dir): - os.makedirs(self.log_dir) + os.makedirs(self.log_dir, exist_ok=True) self.tensorboard_handler = tensorboard_trace_handler(self.log_dir) diff --git a/GPT_SoVITS/Accelerate/__init__.py 
b/GPT_SoVITS/Accelerate/__init__.py index 797fe1d0..5508f6d0 100644 --- a/GPT_SoVITS/Accelerate/__init__.py +++ b/GPT_SoVITS/Accelerate/__init__.py @@ -1,18 +1,13 @@ from . import MLX, PyTorch from .logger import console, logger, tb -from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult +from .MLX import quantization_methods_mlx +from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult, quantization_methods_torch from .PyTorch.structs import T2SEngineProtocol backends = PyTorch.backends + MLX.backends backends = [ - b.replace("_", "-") - .title() - .replace("Mlx", "MLX") - .replace("Mps", "MPS") - .replace("Cuda", "CUDA") - .replace("Mxfp4", "MXFP4") - for b in backends + b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends ] @@ -27,4 +22,6 @@ __all__ = [ "console", "tb", "T2SEngineProtocol", + "quantization_methods_torch", + "quantization_methods_mlx", ] diff --git a/GPT_SoVITS/Accelerate/logger.py b/GPT_SoVITS/Accelerate/logger.py index 021e72e3..bde6f230 100644 --- a/GPT_SoVITS/Accelerate/logger.py +++ b/GPT_SoVITS/Accelerate/logger.py @@ -1,4 +1,7 @@ import sys +import time +from collections import defaultdict +from contextlib import nullcontext from typing import Optional from loguru import logger @@ -201,3 +204,46 @@ if __name__ == "__main__": raise RuntimeError() except Exception: logger.bind(show_locals=False).exception("TEST") + + +class Timer: + def __init__(self): + self.records: dict[str, list[float]] = defaultdict(list) + self._stack: list[tuple[str, int]] = [] + + def __call__(self, category: str, debug=False): + timer = self + + class _Ctx: + def __enter__(self): + timer._stack.append((category, time.perf_counter_ns())) + return timer # 如需在with块里调用timer方法 + + def __exit__(self, exc_type, exc_val, exc_tb): + end = time.perf_counter_ns() + if not timer._stack: + raise RuntimeError("Timer stack underflow: __exit__ without matching __enter__") + cat, start = timer._stack.pop() + if cat 
!= category: + raise RuntimeError(f"Mismatched timer context: expected '{cat}', got '{category}'") + elapsed_sec = (end - start) / 1e9 + timer.records[cat].append(elapsed_sec) + return False + + if debug: + return _Ctx() + else: + return nullcontext() + + def clear(self): + self.records.clear() + self._stack.clear() + + def summary(self): + for cat, times in self.records.items(): + total = sum(times) + avg = total / len(times) if times else 0.0 + print(f"{cat}: count={len(times)}, total={total:.6f}s, avg={avg:.6f}s") + + +timer = Timer() diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index b5150c3f..904affbb 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -51,13 +51,11 @@ warnings.filterwarnings( ) warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*") -logging.getLogger("markdown_it").setLevel(logging.ERROR) logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("httpcore").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("asyncio").setLevel(logging.ERROR) logging.getLogger("charset_normalizer").setLevel(logging.ERROR) -logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) logging.getLogger("multipart.multipart").setLevel(logging.ERROR) os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -88,6 +86,12 @@ def lang_type(text: str) -> str: return "Auto" +def none_or_str(value: str): + if value == "None": + return None + return value + + def build_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( prog="inference_webui", @@ -108,6 +112,15 @@ def build_parser() -> argparse.ArgumentParser: help="AR Inference Backend", required=False, ) + p.add_argument( + "--quantization", + "-q", + default="None", + choices=MLX.quantization_methods_mlx + PyTorch.quantization_methods_torch, + type=none_or_str, + help="Quantization Method", + required=False, + ) p.add_argument( "--device", "-d", @@ -393,10 +406,9 
@@ with contextlib.suppress(UnboundLocalError): def change_gpt_weights(gpt_path): global t2s_engine, config - if "mlx" in ar_backend.lower(): t2s_engine = MLX.T2SEngineMLX( - MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend), + MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization), "mx.gpu" if infer_device.type != "cpu" else "mx.cpu", dtype=dtype, ) @@ -404,7 +416,7 @@ def change_gpt_weights(gpt_path): total = sum((p[-1].size for p in mxutils.tree_flatten(t2s_engine.decoder_model.parameters()))) # type: ignore else: t2s_engine = PyTorch.T2SEngineTorch( - PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend), + PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization), device, dtype=dtype, ) @@ -824,7 +836,7 @@ def get_tts_wav( pred_semantic_list = t2s_result.result assert pred_semantic_list, t2s_result.traceback pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device) - infer_len.append(pred_semantic.shape[-1]) + infer_len.append(t2s_result.total_tokens) infer_time.append(t2s_result.infer_speed[-1]) cache[i_text] = pred_semantic diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index 7ec03e77..ec1e0dd2 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -30,6 +30,7 @@ pinyin_to_symbol_map = { parent_directory = os.path.dirname(current_file_path) is_g2pw = os.getenv("G2PW", "1") == "1" +debug = os.getenv("DEBUG", "0") == "1" if is_g2pw: g2pw = G2PWPinyin( model_dir="GPT_SoVITS/text/G2PWModel", @@ -202,7 +203,8 @@ def _g2p(segments): # assert len(sub_initials) == len(sub_finals) == len(word) initials = sum(initials, []) finals = sum(finals, []) - print("pypinyin结果", initials, finals) + if debug: + print("pypinyin结果", initials, finals) else: # g2pw采用整句推理 pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) diff --git a/README.md b/README.md index 
33c442e9..558c1edf 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.

[![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb) -[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) +[![Huggingface](https://img.shields.io/badge/在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) [![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits) [![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) @@ -57,7 +57,7 @@ Unseen speakers few-shot fine-tuning demo: | RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph | -| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined | +| Apple M4 | 0.16 | UNK | 1 | MLX Varlen | diff --git a/docs/cn/README.md b/docs/cn/README.md index 2f03b3ca..12e746d6 100644 --- a/docs/cn/README.md +++ b/docs/cn/README.md @@ -13,7 +13,7 @@ [![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb) 
-[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) +[![Huggingface](https://img.shields.io/badge/在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) [![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits) [![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) @@ -57,7 +57,7 @@ | RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph | -| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined | +| Apple M4 | 0.16 | UNK | 1 | MLX Varlen | diff --git a/docs/ja/README.md b/docs/ja/README.md index c41f6aed..4f064692 100644 --- a/docs/ja/README.md +++ b/docs/ja/README.md @@ -13,7 +13,7 @@ [![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb) -[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) +[![Huggingface](https://img.shields.io/badge/在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) [![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) @@ -57,7 +57,7 @@ | RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph | -| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined | +| Apple M4 | 0.16 | UNK | 1 | MLX Varlen | diff --git a/docs/ko/README.md b/docs/ko/README.md index 28012977..62b09db1 100644 --- a/docs/ko/README.md +++ b/docs/ko/README.md @@ -13,7 +13,7 @@ [![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb) -[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) +[![Huggingface](https://img.shields.io/badge/在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) [![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits) [![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) @@ -57,7 +57,7 @@ | RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph | -| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined | +| Apple M4 | 0.16 | UNK | 1 | MLX Varlen | diff --git a/docs/tr/README.md 
b/docs/tr/README.md index de94d5bd..8bad35cb 100644 --- a/docs/tr/README.md +++ b/docs/tr/README.md @@ -13,7 +13,7 @@ Güçlü Birkaç Örnekli Ses Dönüştürme ve Metinden Konuşmaya Web Arayüz [![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb) -[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) +[![Huggingface](https://img.shields.io/badge/在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/) [![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits) [![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) @@ -57,7 +57,7 @@ Görünmeyen konuşmacılar birkaç örnekli ince ayar demosu: | RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph | -| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined | +| Apple M4 | 0.16 | UNK | 1 | MLX Varlen | diff --git a/install.ps1 b/install.ps1 index d993e97b..4646a60e 100644 --- a/install.ps1 +++ b/install.ps1 @@ -225,7 +225,7 @@ switch ($Device) { Write-Warning "CUDA 12.8 Is Not Supported By Current Driver" } Write-Info "Installing PyTorch For CUDA 12.8..." 
- Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu128" + Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cu128" Invoke-Conda cuda-nvcc=12.8 Invoke-Pip psutil ninja packaging wheel "setuptools>=42" Write-Info "Installing Flash Attn..." @@ -240,7 +240,7 @@ switch ($Device) { Write-Warning "CUDA 12.6 Is Not Supported By Current Driver" } Write-Info "Installing PyTorch For CUDA 12.6..." - Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu126" + Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cu126" Invoke-Conda cuda-nvcc=12.6 Invoke-Pip psutil ninja packaging wheel "setuptools>=42" Write-Info "Installing Flash Attn..." @@ -249,7 +249,7 @@ switch ($Device) { } "CPU" { Write-Info "Installing PyTorch For CPU..." - Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cpu" + Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cpu" } } Write-Success "PyTorch Installed" diff --git a/install.sh b/install.sh index a5cce698..e7d298af 100644 --- a/install.sh +++ b/install.sh @@ -334,14 +334,14 @@ if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then echo -r "${WARNING}CUDA 12.8 Is Not Supported By Current Driver" fi echo -e "${INFO}Installing PyTorch For CUDA 12.8..." - run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128" + run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cu128" run_conda_quiet cuda-nvcc=12.8 elif [ "$CUDA" = 126 ]; then if awk "BEGIN {exit !($CUDAVERSION < 12.6)}"; then echo -r "${WARNING}CUDA 12.6 Is Not Supported By Current Driver" fi echo -e "${INFO}Installing PyTorch For CUDA 12.6..." 
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126" + run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cu126" run_conda_quiet cuda-nvcc=12.6 fi echo -e "${INFO}Installing Flash Attn" @@ -350,14 +350,14 @@ if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then echo -e "${SUCCESS}Flash Attn Installed" elif [ "$USE_MLX" = true ] && [ "$WORKFLOW" = false ]; then echo -e "${INFO}Installing MLX & PyTorch For MPS..." - run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu" + run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cpu" run_pip_quiet mlx elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then echo -e "${INFO}Installing PyTorch For ROCm 6.2..." - run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2" + run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/rocm6.2" elif [ "$USE_CPU" = true ] && [ "$WORKFLOW" = false ]; then echo -e "${INFO}Installing PyTorch For CPU..." 
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu" + run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cpu" elif [ "$WORKFLOW" = false ]; then echo -e "${ERROR}Unknown Err" exit 1 diff --git a/requirements.txt b/requirements.txt index 0875ffa0..37ab0bce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ peft py-cpuinfo pypinyin split-lang +torchao torchaudio torchcodec transformers diff --git a/test.py b/test.py new file mode 100644 index 00000000..fd4b7127 --- /dev/null +++ b/test.py @@ -0,0 +1,879 @@ +import argparse +import contextlib +import gc +import logging +import os +import re +import traceback +import warnings +from functools import partial +from pathlib import Path +from time import perf_counter as ttime +from typing import Any + +import gradio as gr +import librosa +import numpy as np +import torch +import torchaudio +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from config import ( + change_choices, + get_dtype, + get_weights_names, + pretrained_sovits_name, +) +from config import ( + infer_device as default_device, +) +from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends +from GPT_SoVITS.Accelerate.logger import console, timer +from GPT_SoVITS.feature_extractor import cnhubert +from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch +from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3 +from GPT_SoVITS.process_ckpt import inspect_version +from GPT_SoVITS.text import cleaned_text_to_sequence +from GPT_SoVITS.text.cleaner import clean_text +from GPT_SoVITS.text.LangSegmenter import LangSegmenter +from tools.i18n.i18n import I18nAuto, scan_language_list +from tools.my_utils import DictToAttrRecursive + +warnings.filterwarnings( + "ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively." 
+) +warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*") + +logging.getLogger("urllib3").setLevel(logging.ERROR) +logging.getLogger("httpcore").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("charset_normalizer").setLevel(logging.ERROR) +logging.getLogger("multipart.multipart").setLevel(logging.ERROR) + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + + +_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$") + + +def lang_type(text: str) -> str: + if text == "Auto": + return text + if not _LANG_RE.match(text): + raise argparse.ArgumentTypeError(f"Unspported Format: {text}, Expected ll_CC/ll-CC") + ll, cc = re.split(r"[_-]", text) + language = f"{ll}_{cc}" + if language in scan_language_list(): + return language + else: + return "Auto" + + +def none_or_str(value: str): + if value == "None": + return None + return value + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="inference_webui", + description=f"python -s -m GPT_SoVITS.inference_webui zh_CN -b {backends[-1]}", + ) + p.add_argument( + "language", + nargs="?", + default="Auto", + type=lang_type, + help="Language Code, Such as zh_CN, en-US", + ) + p.add_argument( + "--backends", + "-b", + choices=backends, + default=backends[-1], + help="AR Inference Backend", + required=False, + ) + p.add_argument( + "--quantization", + "-q", + default="None", + choices=MLX.quantization_methods_mlx + PyTorch.quantization_methods_torch, + type=none_or_str, + help="Quantization Method", + required=False, + ) + p.add_argument( + "--device", + "-d", + default=str(default_device), + help="Inference Device", + required=False, + ) + p.add_argument( + "--port", + "-p", + default=9872, + type=int, + help="WebUI Binding Port", + required=False, + ) + p.add_argument( + "--share", + "-s", + default=False, + 
action="store_true", + help="Gradio Share Link", + required=False, + ) + p.add_argument( + "--cnhubert", + default="GPT_SoVITS/pretrained_models/chinese-hubert-base", + help="CNHuBERT Pretrain", + required=False, + ) + p.add_argument( + "--bert", + default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + help="BERT Pretrain", + required=False, + ) + p.add_argument( + "--gpt", + default="", + help="GPT Model", + required=False, + ) + p.add_argument( + "--sovits", + default="", + help="SoVITS Model", + required=False, + ) + + return p + + +args = build_parser().parse_args() + +hps: Any = None +vq_model: SynthesizerTrn | SynthesizerTrnV3 | None = None +t2s_engine: T2SEngineProtocol | None = None + +version = model_version = "v2" +path_sovits_v3 = pretrained_sovits_name["v3"] +path_sovits_v4 = pretrained_sovits_name["v4"] +is_exist_s2gv3 = os.path.exists(path_sovits_v3) +is_exist_s2gv4 = os.path.exists(path_sovits_v4) + +cnhubert_base_path = str(args.cnhubert) +bert_path = str(args.bert) +infer_ttswebui = int(args.port) +is_share = bool(args.share) + + +i18n = I18nAuto(language=args.language) +ar_backend: str = args.backends +change_choices_i18n = partial(change_choices, i18n=i18n) + +SoVITS_names, GPT_names = get_weights_names(i18n) + + +dict_language_v1 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("中英混合"): "zh", # 按中英混合识别 + i18n("日英混合"): "ja", # 按日英混合识别 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 +} +dict_language_v2 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("粤语"): "all_yue", # 全部按粤语识别 + i18n("韩文"): "all_ko", # 全部按韩文识别 + i18n("中英混合"): "zh", + i18n("日英混合"): "ja", + i18n("粤英混合"): "yue", + i18n("韩英混合"): "ko", + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 + i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种 +} +dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + +punctuation = set(["!", "?", "…", ",", ".", "-", " "]) 
+splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"} +v3v4set = {"v3", "v4"} + +infer_device = torch.device(args.device) +device = infer_device if infer_device.type == "cuda" else torch.device("cpu") + +dtype = get_dtype(device.index) +is_half = dtype == torch.float16 + +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).to(infer_device, dtype) + +cnhubert.cnhubert_base_path = cnhubert_base_path +ssl_model = cnhubert.get_model().to(infer_device, dtype) + + +def mel_fn(x): + return mel_spectrogram_torch( + y=x, + n_fft=1024, + num_mels=100, + sampling_rate=24000, + hop_size=256, + win_size=1024, + fmin=0, + fmax=None, + center=False, + ) + + +gpt_path = str(args.gpt) or GPT_names[0][-1] +sovits_path = str(args.sovits) or SoVITS_names[0][-1] + + +def get_bert_feature(text, word2ph): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(infer_device) + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature_t = torch.cat(phone_level_feature, dim=0) + return phone_level_feature_t.T + + +def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): + global vq_model, hps, version, model_version, dict_language + model_version, version, is_lora, hps, dict_s2 = inspect_version(sovits_path) + is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + if is_lora is True and is_exist is False: + info = f"{path_sovits} SoVITS {model_version} {i18n('底模缺失,无法加载相应 LoRA 权重')}" + gr.Warning(info) + raise FileNotFoundError(info) + dict_language = dict_language_v1 if version == "v1" 
else dict_language_v2 + visible_sample_steps = visible_inp_refs = None + if prompt_language is not None and text_language is not None: + if prompt_language in list(dict_language.keys()): + prompt_text_update, prompt_language_update = gr.skip(), gr.update(choices=list(dict_language.keys())) + else: + prompt_text_update = gr.update(value="") + prompt_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys())) + if text_language in list(dict_language.keys()): + text_update, text_language_update = gr.skip(), gr.skip() + else: + text_update = gr.update(value="") + text_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys())) + + if model_version in v3v4set: + visible_sample_steps = True + visible_inp_refs = False + else: + visible_sample_steps = False + visible_inp_refs = True + yield ( + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + gr.update( + visible=visible_sample_steps, + value=32 if model_version == "v3" else 8, + choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32], + ), + gr.update(visible=visible_inp_refs), + gr.update(value=False, interactive=True if model_version not in v3v4set else False), + gr.update(visible=True if model_version == "v3" else False), + gr.update(value=i18n("模型加载中,请等待"), interactive=False), + ) + + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + hps.model.version = model_version + if model_version not in v3v4set: + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + else: + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).eval() + + if "pretrained" not in sovits_path: + if hasattr(vq_model, "enc_q"): + del vq_model.enc_q + + vq_model.load_state_dict(dict_s2["weight"], 
strict=False) + + vq_model = vq_model.to(infer_device, dtype) + + yield ( + gr.skip(), + gr.skip(), + gr.skip(), + gr.skip(), + gr.skip(), + gr.skip(), + gr.skip(), + gr.skip(), + gr.update(value=i18n("合成语音"), interactive=True), + ) + + +with contextlib.suppress(UnboundLocalError): + next(change_sovits_weights(sovits_path)) + + +def change_gpt_weights(gpt_path): + global t2s_engine, config + if "mlx" in ar_backend.lower(): + t2s_engine = MLX.T2SEngineMLX( + MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization), + "mx.gpu" if infer_device.type != "cpu" else "mx.cpu", + dtype=dtype, + ) + # t2s_engine.decoder_model.compile() + else: + t2s_engine = PyTorch.T2SEngineTorch( + PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization), + device, + dtype=dtype, + ) + # t2s_engine.decoder_model.compile() + + +change_gpt_weights(gpt_path) + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0, sr1, device): + global resample_transform_dict + key = f"{sr0}-{sr1}-{device}" + if key not in resample_transform_dict: + resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) + return resample_transform_dict[key](audio_tensor) + + +def get_spepc(hps, filename, dtype, device, is_v2pro=False): + sr1 = int(hps.data.sampling_rate) + audio, sr0 = torchaudio.load_with_torchcodec(filename) + audio = audio.to(device) + + if sr0 != sr1: + audio = resample(audio, sr0, sr1, device) + if audio.shape[0] > 1: + audio = audio.mean(0).unsqueeze(0) + + maxx = float(audio.abs().max()) + if maxx > 1: + audio /= min(2, maxx) + spec = spectrogram_torch( + audio, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + spec = spec.to(dtype) + if is_v2pro is True: + audio = resample(audio, sr1, 16000, device).to(dtype) + return spec, audio + + +def clean_text_inf(text, language, version): + language = 
language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + +def get_bert_inf(phones, word2ph, norm_text, language): + language = language.replace("all_", "") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half is True else torch.float32, + ).to(device) + + return bert + + +def get_first(text): + pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" + text = re.split(pattern, text)[0].strip() + return text + + +def get_phones_and_bert(text, language, version, final=False): + text = re.sub(r" {2,}", " ", text) + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): + textlist[-1] += 
tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return get_phones_and_bert("." + text, language, version, final=True) + + return phones, bert.to(dtype), norm_text + + +def merge_short_text_in_array(texts, threshold): + if (len(texts)) < 2: + return texts + result = [] + text = "" + for ele in texts: + text += ele + if len(text) >= threshold: + result.append(text) + text = "" + if len(text) > 0: + if len(result) == 0: + result.append(text) + else: + result[len(result) - 1] += text + return result + + +sr_model = None + + +def audio_sr(audio, sr): + global sr_model + if sr_model is None: + from tools.audio_sr import AP_BWE + + try: + sr_model = AP_BWE(infer_device, DictToAttrRecursive) + except FileNotFoundError: + gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好")) + return audio.cpu().numpy(), sr + return sr_model(audio, sr) + + +cache: dict[int, Any] = {} + + +def get_tts_wav( + ref_wav_path, + prompt_text, + prompt_language, + text, + text_language, + how_to_cut=i18n("不切"), + top_k=20, + top_p=0.6, + temperature=0.6, + ref_free=False, + speed=1, + if_freeze=False, + inp_refs=None, + sample_steps=8, + if_sr=False, + pause_second=0.3, +): + torch.set_grad_enabled(False) + debug = os.getenv("DEBUG") == "1" + ttfb_time = ttime() + + if ref_wav_path: + pass + else: + gr.Warning(i18n("请上传参考音频")) + if text: + pass + else: + gr.Warning(i18n("请填入推理文本")) + t = [] + if 
prompt_text is None or len(prompt_text) == 0: + ref_free = True + if model_version in v3v4set: + ref_free = False # s2v3暂不支持ref_free + t0 = ttime() + prompt_language = dict_language[prompt_language] + text_language = dict_language[text_language] + + if not ref_free: + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." + text = text.strip("\n") + + zero_wav = np.zeros( + int(hps.data.sampling_rate * pause_second), + dtype=np.float16 if is_half is True else np.float32, + ) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half is True: + zero_wav_torch = zero_wav_torch.half().to(infer_device) + else: + zero_wav_torch = zero_wav_torch.to(infer_device) + if not ref_free: + assert vq_model + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + gr.Warning(i18n("参考音频在3~10秒范围外,请更换!")) + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k_t = torch.from_numpy(wav16k) + if is_half is True: + wav16k_t = wav16k_t.half().to(infer_device) + else: + wav16k_t = wav16k_t.to(infer_device) + wav16k_t = torch.cat([wav16k_t, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k_t.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + else: + prompt = torch.zeros((1, 0)).to(device, torch.int32) + + t1 = ttime() + t.append(t1 - t0) + + if how_to_cut == i18n("凑四句一切"): + text = cut1(text) + elif how_to_cut == i18n("凑50字一切"): + text = cut2(text) + elif how_to_cut == i18n("按中文句号。切"): + text = cut3(text) + elif how_to_cut == i18n("按英文句号.切"): + text = cut4(text) + elif how_to_cut == i18n("按标点符号切"): + text = cut5(text) + while "\n\n" in text: + text = text.replace("\n\n", "\n") + texts = text.split("\n") + texts = merge_short_text_in_array(texts, 5) + audio_opt = [] + # s2v3暂不支持ref_free + if not ref_free: + phones1, 
bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version) + else: + phones1, bert1 = [], torch.zeros(1024, 0).to(device, dtype) + + infer_len: list[int] = [] + infer_time: list[float] = [] + assert vq_model + + for i_text, text in enumerate(texts): + if len(text.strip()) == 0: + continue + if text[-1] not in splits: + text += "。" if text_language != "en" else "." + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) + + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + + t2 = ttime() + if i_text in cache and if_freeze is True: + pred_semantic = cache[i_text] + else: + t2s_request = T2SRequest( + [all_phoneme_ids.squeeze(0)], + all_phoneme_len, + prompt, + [bert.squeeze(0)], + valid_length=1, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=1500, + use_cuda_graph=torch.cuda.is_available(), # Try to use CUDA Graph for all backend, fallback to normal if not applicapble + debug=debug, + ) + assert t2s_engine + t2s_result = t2s_engine.generate(t2s_request) + if t2s_result.exception is not None: + console.print(t2s_result.traceback) + raise RuntimeError() + pred_semantic_list = t2s_result.result + assert pred_semantic_list, t2s_result.traceback + pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device) + infer_len.append(t2s_result.total_tokens) + infer_time.append(t2s_result.infer_speed[-1]) + + cache[i_text] = pred_semantic + t3 = ttime() + is_v2pro = model_version in {"v2Pro", "v2ProPlus"} + + refers = [] + if inp_refs: + for path in inp_refs: + try: # 这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer + refer, audio_tensor = get_spepc(hps, path.name, dtype, infer_device, is_v2pro) + refers.append(refer) + except Exception as e: + print(e) + traceback.print_exc() + if len(refers) == 0: + refers, audio_tensor = get_spepc(hps, 
ref_wav_path, dtype, infer_device, is_v2pro) + refers = [refers] + audio = vq_model.decode( + pred_semantic, + torch.LongTensor(phones2).to(infer_device).unsqueeze(0), + refers, + speed=speed, + )[0][0] # type: ignore + + if i_text == 0: + ttfb_time = ttime() - ttfb_time + max_audio = torch.abs(audio).max() # 简单防止16bit爆音 + if max_audio > 1: + audio = audio / max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav_torch) # zero_wav + t4 = ttime() + t.extend([t2 - t1, t3 - t2, t4 - t3]) + t1 = ttime() + + audio_opt_t = torch.cat(audio_opt, 0) # np.concatenate + if model_version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + opt_sr = 32000 + elif model_version == "v3": + opt_sr = 24000 + else: + opt_sr = 48000 # v4 + audio_opt_n = audio_opt_t.cpu().numpy() + + t0 = t[0] + t1 = sum(t[1::3]) + t2 = sum(t[2::3]) + t3 = sum(t[3::3]) + + infer_speed_avg = sum(infer_len) / sum(infer_time) + rtf_value = sum(t) / (audio_opt_n.__len__() / opt_sr) + + console.print(f">> Time Stamps: {t0:.3f}\t{t1:.3f}\t{t2:.3f}\t{t3:.3f}") + console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s") + console.print(f">> RTF: {rtf_value:.2f}") + + if ttfb_time > 2: + console.print(f">> TTFB: {ttfb_time:.3f} s") + else: + console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms") + + yield opt_sr, (audio_opt_n * 32767).astype(np.int16) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + +def split(todo_text): + todo_text = todo_text.replace("……", "。").replace("——", ",") + if todo_text[-1] not in splits: + todo_text += "。" + i_split_head = i_split_tail = 0 + len_text = len(todo_text) + todo_texts = [] + while 1: + if i_split_head >= len_text: + break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 + if todo_text[i_split_head] in splits: + i_split_head += 1 + todo_texts.append(todo_text[i_split_tail:i_split_head]) + i_split_tail = i_split_head + else: + i_split_head += 1 + return todo_texts + + +def cut1(inp): + inp = inp.strip("\n") + inps = split(inp) + split_idx: list[int | None] = 
list(range(0, len(inps) + 1, 4)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) + else: + opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut2(inp): + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return inp + opts = [] + summ = 0 + tmp_str = "" + for i in range(len(inps)): + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 + opts.append(tmp_str) + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + if len(opts) > 1 and len(opts[-1]) < 50: # 如果最后一个太短了,和前一个合一起 + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut3(inp): + inp = inp.strip("\n") + opts = inp.strip("。").split("。") + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut4(inp): + inp = inp.strip("\n") + opts = re.split(r"(? 
0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit(): + items.append(char) + else: + items.append(char) + mergeitems.append("".join(items)) + items = [] + else: + items.append(char) + + if items: + mergeitems.append("".join(items)) + + opt = [item for item in mergeitems if not set(item).issubset(punds)] + return "\n".join(opt) + + +a = get_tts_wav( + "/Users/XXXXRT/Desktop/参考/不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊.wav", + "不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊", + i18n("中文"), + "我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜", + i18n("中文"), +) + +next(a) + +timer.clear() + +a = get_tts_wav( + "/Users/XXXXRT/Desktop/参考/Cream去能理解很多人的想法时,既然已经被这样想了,没有挽回的余地了.wav", + "去能理解很多人的想法时,既然已经被这样想了,没有挽回的余地了", + i18n("中文"), + "我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜", + i18n("中文"), +) + + +next(a) + +timer.summary() + +a = get_tts_wav( + "/Users/XXXXRT/Desktop/参考/不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊.wav", + "不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊", + i18n("中文"), + "我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜", + i18n("中文"), +) + + +next(a) + +timer.summary() + +print("-" * 15 + "test2" + "-" * 15) diff --git a/webui.py b/webui.py index 602e9e1b..3c20f361 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ import traceback from functools import partial from multiprocessing import cpu_count from subprocess import Popen +from typing import cast import gradio as gr import psutil @@ -37,7 +38,15 @@ from config import ( webui_port_subfix, webui_port_uvr5, ) -from GPT_SoVITS.Accelerate import backends, console, logger +from GPT_SoVITS.Accelerate import ( + MLX, + PyTorch, + backends, + console, + logger, + quantization_methods_mlx, + quantization_methods_torch, +) from tools import my_utils from tools.asr.config import asr_dict from tools.assets import css, js, top_html @@ -310,8 +319,8 @@ def change_tts_inference( sovits_path: str, batched_infer_enabled: bool, backends_dropdown: str, + quantization_methods_dropdown: str | None, ): - 
console.print(gpt_path, sovits_path) global p_tts_inference env = os.environ.copy() cmd: list[str] = [python_exec, "-s", "-m"] @@ -334,6 +343,7 @@ def change_tts_inference( "-b", backends_dropdown, "-d", f"{infer_device.type}:{gpu_number}", "-p", str(webui_port_infer_tts), + "-q", str(quantization_methods_dropdown), "--gpt", gpt_path, "--sovits", sovits_path, ] @@ -344,7 +354,6 @@ def change_tts_inference( if p_tts_inference is None: yield ( - process_info(process_name_tts, "opened"), gr.update(visible=False), gr.update(visible=True), ) @@ -354,7 +363,6 @@ def change_tts_inference( kill_process(p_tts_inference.pid, process_name_tts) p_tts_inference = None yield ( - process_info(process_name_tts, "closed"), gr.update(visible=True), gr.update(visible=False), ) @@ -1280,6 +1288,21 @@ def changeBackend(flag: bool): return gr.update(choices=backends_gradio, value=backends_gradio[-1][-1]) +def changeQuantization(backend: str, gradio_call=True): + backend = backend.lower().replace("-", "_") + if backend in MLX.backends: + choices = quantization_methods_mlx + elif backend in PyTorch.backends: + choices = quantization_methods_torch + else: + choices = [None] + + if gradio_call: + return gr.update(choices=choices, value=None) + else: + return choices + + GPU_INDEX.add(0) GPU_INDEX_LIST = list(GPU_INDEX) GPU_INDEX_LIST.sort() @@ -1891,7 +1914,13 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css interactive=True, ) with gr.Row(equal_height=True): - tts_info = gr.Textbox(label=process_info(process_name_tts, "info")) + with gr.Column(): + quantization_methods_dropdown = gr.Dropdown( + choices=cast(list, changeQuantization(backends_gradio[-1][-1], gradio_call=False)), + label=i18n("量化方法"), + value=None, + interactive=True, + ) open_tts = gr.Button( value=process_info(process_name_tts, "open"), variant="primary", visible=True ) @@ -1904,6 +1933,12 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css 
[batched_infer_enabled], [backends_dropdown], ) + backends_dropdown.change( + changeQuantization, + [backends_dropdown], + [quantization_methods_dropdown], + ) + open_tts.click( change_tts_inference, [ @@ -1912,8 +1947,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css SoVITS_dropdown, batched_infer_enabled, backends_dropdown, + quantization_methods_dropdown, ], - [tts_info, open_tts, close_tts], + [open_tts, close_tts], ) close_tts.click( change_tts_inference, @@ -1923,8 +1959,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css SoVITS_dropdown, batched_infer_enabled, backends_dropdown, + quantization_methods_dropdown, ], - [tts_info, open_tts, close_tts], + [open_tts, close_tts], ) button1Ba_open.click( open1Ba,