From c1a4ff476c7050535db2289bee58ef0ca392f7ae Mon Sep 17 00:00:00 2001
From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com>
Date: Sun, 19 Oct 2025 21:51:54 +0100
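Subject: [PATCH] Replace MLX quantized backend with load-time quantize_mode; add torchao installs, compiled sampling helpers, and debug timers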
---
.github/build_windows_packages.ps1 | 4 +-
.gitignore | 1 +
Docker/miniconda_install.sh | 4 +-
GPT_SoVITS/Accelerate/MLX/__init__.py | 6 +-
.../Accelerate/MLX/backends/mlx_quantized.py | 179 ----
.../Accelerate/MLX/backends/mlx_static.py | 17 +-
.../Accelerate/MLX/backends/mlx_varlen.py | 9 +-
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py | 93 +-
GPT_SoVITS/Accelerate/MLX/structs_mlx.py | 18 +-
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py | 206 ++--
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py | 224 ++---
GPT_SoVITS/Accelerate/PyTorch/__init__.py | 35 +-
.../backends/flash_attn_varlen_cuda_graph.py | 2 +-
.../PyTorch/backends/mps_flash_attn_varlen.py | 2 +-
.../backends/sage_attn_varlen_cuda_graph.py | 2 +-
.../backends/torch_static_cuda_graph.py | 2 +-
.../PyTorch/backends/torch_varlen.py | 2 +-
GPT_SoVITS/Accelerate/PyTorch/quantization.py | 158 ++++
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py | 115 ++-
GPT_SoVITS/Accelerate/PyTorch/structs.py | 4 +-
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py | 113 +--
.../Accelerate/PyTorch/t2s_model_abc.py | 56 +-
GPT_SoVITS/Accelerate/__init__.py | 13 +-
GPT_SoVITS/Accelerate/logger.py | 46 +
GPT_SoVITS/inference_webui.py | 24 +-
GPT_SoVITS/text/chinese2.py | 4 +-
README.md | 4 +-
docs/cn/README.md | 4 +-
docs/ja/README.md | 4 +-
docs/ko/README.md | 4 +-
docs/tr/README.md | 4 +-
install.ps1 | 6 +-
install.sh | 10 +-
requirements.txt | 1 +
test.py | 879 ++++++++++++++++++
webui.py | 51 +-
36 files changed, 1660 insertions(+), 646 deletions(-)
delete mode 100644 GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
create mode 100644 GPT_SoVITS/Accelerate/PyTorch/quantization.py
create mode 100644 test.py
diff --git a/.github/build_windows_packages.ps1 b/.github/build_windows_packages.ps1
index 606e6089..3a967ef0 100644
--- a/.github/build_windows_packages.ps1
+++ b/.github/build_windows_packages.ps1
@@ -157,12 +157,12 @@ Write-Host "[INFO] Installing PyTorch..."
switch ($cuda) {
"cu126" {
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location --no-cache-dir
- & ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location --no-cache-dir
+ & ".\runtime\python.exe" -m pip install torch torchao --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location --no-cache-dir
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation --no-cache-dir
}
"cu128" {
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location --no-cache-dir
- & ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location --no-cache-dir
+ & ".\runtime\python.exe" -m pip install torch torchao --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location --no-cache-dir
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation --no-cache-dir
}
default {
diff --git a/.gitignore b/.gitignore
index d9e61e9f..2c8dfe98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ tools/AP_BWE/24kto48k/*
!tools/AP_BWE/24kto48k/readme.txt
onnx_export
compile_cache
+profiler
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/Docker/miniconda_install.sh b/Docker/miniconda_install.sh
index cf2e3d6f..7873116c 100644
--- a/Docker/miniconda_install.sh
+++ b/Docker/miniconda_install.sh
@@ -60,10 +60,10 @@ source "$HOME/.bashrc"
"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y
if [ "$CUDA_VERSION" = "12.8" ]; then
- "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
+ "$HOME/miniconda3/bin/pip" install torch torchao --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
elif [ "$CUDA_VERSION" = "12.6" ]; then
- "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
+ "$HOME/miniconda3/bin/pip" install torch torchao --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
fi
diff --git a/GPT_SoVITS/Accelerate/MLX/__init__.py b/GPT_SoVITS/Accelerate/MLX/__init__.py
index 20691fcf..f98c8c40 100644
--- a/GPT_SoVITS/Accelerate/MLX/__init__.py
+++ b/GPT_SoVITS/Accelerate/MLX/__init__.py
@@ -5,8 +5,10 @@ if importlib.util.find_spec("mlx") is not None and platform.system() == "Darwin"
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
- backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
+ backends = ["mlx_static", "mlx_varlen"]
else:
backends = []
-__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
+quantization_methods_mlx = [None, "MXFP4", "Affine"]
+
+__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends", "quantization_methods_mlx"]
diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
deleted file mode 100644
index 54ceddb5..00000000
--- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
+++ /dev/null
@@ -1,179 +0,0 @@
-from __future__ import annotations
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from ..structs_mlx import KVCacheQ
-from ..t2s_model_abc import (
- AttentionABC,
- KVCache,
- KVCacheHND,
- T2SDecoderABC,
- TransformerBlockABC,
- TransformerDecoderABC,
-)
-
-Array = mx.array
-
-
-class Attention(AttentionABC):
- def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
- super().__init__(n_head, hidden_dim, max_seq_length)
- self.kc_class = KVCacheHND
-
- @staticmethod
- def quantized_scaled_dot_product_attention(
- queries: Array,
- q_keys: tuple[Array, Array, Array],
- q_values: tuple[Array, Array, Array],
- scale: float,
- mask: Array,
- group_size: int = 32,
- bits: int = 8,
- ) -> Array:
- queries *= scale
-
- scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
- scores = mx.where(mask, scores, -mx.inf)
- scores = mx.softmax(scores, axis=-1, precise=True) # type: ignore
- out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
-
- return out
-
- def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
- bsz, seqlen, _ = x.shape
-
- q, k, v = self.in_proj(x).split(3, axis=-1)
-
- q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
-
- q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
-
- kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
- assert len(kv_cache) == 2
-
- max_idx = int(input_pos.max())
-
- q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
-
- mask = attn_mask[..., :max_idx]
-
- attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
-
- attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
-
- attn = self.out_proj(attn)
-
- return attn
-
- # def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
- # bsz, seqlen, _ = x.shape
-
- # q, k, v = self.in_proj(x).split(3, axis=-1)
-
- # q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
-
- # q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
-
- # kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
-
- # assert len(kv_cache) == 3
- # (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
-
- # k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
-
- # mask = attn_mask[..., : int(input_pos.max())]
-
- # attn = Attention.quantized_scaled_dot_product_attention(
- # q,
- # (k_q, k_s, k_b),
- # (v_q, v_s, v_b),
- # self.scale,
- # mask,
- # group_size,
- # bits,
- # )
-
- # attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
-
- # output = self.out_proj(attn)
-
- # return output
-
-
-class TransformerBlock(TransformerBlockABC):
- def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
- super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
-
- self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
-
-
-class TransformerDecoder(TransformerDecoderABC):
- def __init__(
- self,
- hidden_dim: int,
- n_layer: int,
- n_head: int,
- ffn_dim: int,
- vocab_size: int,
- max_seq_length: int,
- max_batch_size: int,
- *args,
- **kwds,
- ) -> None:
- super().__init__(
- hidden_dim,
- n_layer,
- n_head,
- ffn_dim,
- vocab_size,
- max_seq_length,
- max_batch_size,
- *args,
- **kwds,
- )
-
- self.layers = [
- TransformerBlock(
- n_head,
- ffn_dim,
- hidden_dim,
- max_seq_length,
- *args,
- **kwds,
- )
- for _ in range(n_layer)
- ]
-
-
-class T2SDecoder(T2SDecoderABC):
- def __init__(
- self,
- config: dict,
- max_seq_length: int = 2000,
- max_batch_size: int = 10,
- ) -> None:
- super().__init__(config, max_seq_length, max_batch_size)
-
- self.h = TransformerDecoder(
- self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
- )
-
- self.kv_class = KVCacheHND
- self.group_size = 32
- self.bits = 8
- self.mode = "affine"
-
- def set_mode(self, mode: str):
- assert mode in ["affine", "mxfp4"]
- self.mode = mode
- if self.mode == "mxfp4":
- self.bits = 4
- else:
- self.bits = 8
-
- def quantized(self):
- nn.quantize(self, self.group_size, self.bits, mode=self.mode)
- # for layer in self.h.layers:
- # nn.quantize(layer.feed_forward, self.group_size, self.bits)
- # nn.quantize(layer.attention, self.group_size, self.bits)
diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
index ca402210..e6724eb2 100644
--- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
+++ b/GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import mlx.core as mx
-from ..structs_mlx import KVCache, KVCacheQ
+from ..structs_mlx import KVCache
from ..t2s_model_abc import (
AttentionABC,
KVCacheHND,
@@ -19,23 +19,24 @@ class Attention(AttentionABC):
super().__init__(n_head, hidden_dim, max_seq_length)
self.kc_class = KVCacheHND
- def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+ def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array):
bsz, seqlen, _ = x.shape
- q, k, v = self.in_proj(x).split(3, axis=-1)
+ qkv = self.in_proj(x)
- q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+ q, k, v = mx.split(qkv, 3, -1)
- q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+ q = q.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
+ k = k.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
+ v = v.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
- assert len(kv_cache) == 2
k, v = kv_cache
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
- attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+ attn = attn.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1)
attn = self.out_proj(attn)
@@ -85,7 +86,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config: dict,
- max_seq_length: int = 2000,
+ max_seq_length: int = 1500,
max_batch_size: int = 10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
index 92d51dc3..11c73589 100644
--- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
+++ b/GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import mlx.core as mx
-from ..structs_mlx import KVCache, KVCacheQ
+from ..structs_mlx import KVCache
from ..t2s_model_abc import (
AttentionABC,
KVCacheHND,
@@ -19,7 +19,7 @@ class Attention(AttentionABC):
super().__init__(n_head, hidden_dim, max_seq_length)
self.kc_class = KVCacheHND
- def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+ def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).split(3, axis=-1)
@@ -29,7 +29,6 @@ class Attention(AttentionABC):
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
- assert len(kv_cache) == 2
max_idx = int(input_pos.max())
@@ -39,7 +38,7 @@ class Attention(AttentionABC):
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
- attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+ attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, -1)
attn = self.out_proj(attn)
@@ -89,7 +88,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config: dict,
- max_seq_length: int = 2000,
+ max_seq_length: int = 1500,
max_batch_size: int = 10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
diff --git a/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py b/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
index ee4ada55..dcc02a63 100644
--- a/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
+++ b/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
@@ -1,3 +1,4 @@
+from functools import partial
from typing import Protocol
import mlx.core as mx
@@ -17,8 +18,56 @@ class SampleProtocolMLX(Protocol):
) -> Array: ...
+def apply_repetition_penalty(logits: Array, previous_tokens: Array, repetition_penalty: float):
+ batch_idx = mx.arange(previous_tokens.shape[0])
+ selected_logits = logits[batch_idx, previous_tokens]
+ selected_logits = mx.where(
+ selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty
+ )
+ logits[batch_idx, previous_tokens] = selected_logits
+ return logits
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def apply_greedy_sampling(logits: Array):
+ return mx.argmax(logits, axis=-1, keepdims=True).astype(mx.int32)
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def apply_temperature(logits: Array, temperature: float):
+ return logits / temperature
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def apply_top_k(logits: Array, top_k: int):
+ v = mx.topk(logits, top_k)
+ pivot = mx.expand_dims(v[:, 0], -1)
+ logits = mx.where(logits < pivot, -mx.inf, logits)
+ return logits
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def apply_top_p(logits: Array, top_p: float):
+ sorted_indices = mx.argsort(-logits, axis=-1)
+ sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
+ cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[:, -1] = False
+ indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
+ batch_indices = mx.arange(logits.shape[0])[:, None]
+ indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
+ logits = mx.where(indices_to_remove, -mx.inf, logits)
+ return logits
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def apply_sampling(logits: Array):
+ gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype)
+ idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
+ return idx_next
+
+
class sample_naive(SampleProtocolMLX):
- # @partial(mx.compile)
@staticmethod
def __call__(
logits,
@@ -28,38 +77,18 @@ class sample_naive(SampleProtocolMLX):
top_p,
repetition_penalty,
):
- if temperature <= 1e-5:
- probs = mx.softmax(logits, axis=-1)
- return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32)
-
if repetition_penalty != 1.0:
- batch_idx = mx.arange(previous_tokens.shape[0])
- previous_tokens = previous_tokens.astype(mx.int64)
- selected_logists = logits[batch_idx, previous_tokens]
- selected_logists = mx.where(
- selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
- )
- logits[batch_idx, previous_tokens] = selected_logists
+ logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty)
+
+ if temperature <= 1e-5:
+ return apply_greedy_sampling(logits)
+ elif temperature < 1.0:
+ logits = apply_temperature(logits, temperature)
+
+ if top_k < 1025:
+ logits = apply_top_k(logits, top_k)
if top_p < 1.0:
- sorted_indices = mx.argsort(-logits, axis=-1)
- sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
- cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
- sorted_indices_to_remove = cum_probs > top_p
- sorted_indices_to_remove[:, -1] = False
- indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
- batch_indices = mx.arange(logits.shape[0])[:, None]
- indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
- logits = mx.where(indices_to_remove, -mx.inf, logits)
+ logits = apply_top_p(logits, top_p)
- if temperature < 1.0:
- logits = logits / temperature
-
- v = mx.topk(logits, top_k)
- pivot = mx.expand_dims(v[:, 0], -1)
- logits = mx.where(logits < pivot, -mx.inf, logits)
-
- gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype)
- idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
-
- return idx_next
+ return apply_sampling(logits)
diff --git a/GPT_SoVITS/Accelerate/MLX/structs_mlx.py b/GPT_SoVITS/Accelerate/MLX/structs_mlx.py
index ce2fd003..426309e6 100644
--- a/GPT_SoVITS/Accelerate/MLX/structs_mlx.py
+++ b/GPT_SoVITS/Accelerate/MLX/structs_mlx.py
@@ -29,6 +29,7 @@ class T2SRequestMLX:
early_stop_num: int = -1
temperature: float = 1.0
repetition_penalty: float = 1.35
+ debug: bool = False
@classmethod
def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
@@ -48,29 +49,27 @@ class T2SRequestMLX:
request.early_stop_num,
request.temperature,
request.repetition_penalty,
+ request.debug,
)
KVCache: TypeAlias = tuple[Array, Array]
-KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
class KVCacheProtocol(Protocol):
@staticmethod
- def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
+ def empty(kv_cache: KVCache) -> None: ...
@staticmethod
- def update_cache(
- input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
- ) -> KVCache | KVCacheQ: ...
+ def update_cache(input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache, cache_idx: Array) -> KVCache: ...
@staticmethod
- def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
+ def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache) -> None: ...
@staticmethod
def init_cache(
batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
- ) -> KVCache | KVCacheQ: ...
+ ) -> KVCache: ...
class T2SDecoderProtocol(Protocol):
@@ -104,7 +103,7 @@ class T2SSessionMLX:
self.y_len = y_len
# Cache
- self.kv_cache: MutableSequence[KVCache | KVCacheQ]
+ self.kv_cache: MutableSequence[KVCache]
self.sample = sample_func()
# Forward args
@@ -118,6 +117,7 @@ class T2SSessionMLX:
self.input_pos = mx.zeros_like(self.prefill_len)
self.input_pos += self.prefill_len
+ self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement
# EOS
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
@@ -148,5 +148,3 @@ class T2SSessionMLX:
attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
self.attn_mask = attn_mask
-
- mx.eval(self.attn_mask)
diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
index 13b4c088..d21557cb 100644
--- a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
+++ b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
@@ -1,22 +1,23 @@
-import gc
import os
import time
import traceback
-from typing import cast
+from typing import Literal, cast
import mlx.core as mx
import torch
from rich.progress import BarColumn, Progress, TextColumn
-from ..logger import SpeedColumnToken, console, logger
+from ..logger import SpeedColumnToken, Timer, console, logger
from ..PyTorch.structs import T2SEngineProtocol, T2SRequest, T2SResult
-from .backends import mlx_quantized, mlx_static, mlx_varlen
+from .backends import mlx_static, mlx_varlen
from .structs_mlx import T2SSessionMLX
from .t2s_model_abc import T2SDecoderABC
Array = mx.array
Tensor = torch.Tensor
+timer = Timer()
+
class T2SEngine(T2SEngineProtocol):
def __init__(
@@ -31,10 +32,10 @@ class T2SEngine(T2SEngineProtocol):
device = mx.Device(mx.cpu)
case "mx.gpu":
device = mx.Device(mx.gpu)
-
+ device = cast(mx.Device, device)
match dtype:
case torch.float32:
- dtype = mx.float32
+ dtype = mx.float16 if device.type == mx.gpu else mx.float32
case torch.float16:
dtype = mx.float16
case torch.bfloat16:
@@ -59,13 +60,14 @@ class T2SEngine(T2SEngineProtocol):
decoder = self.decoder_model
session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
batch_idx = mx.arange(session.bsz)
+ debug = request.debug
t1 = 0.0
infer_speed = 0.0
infer_time = 0.0
+ idx = 0
with (
- mx.stream(session.device),
Progress(
TextColumn("[cyan]{task.description}"),
BarColumn(),
@@ -75,29 +77,47 @@ class T2SEngine(T2SEngineProtocol):
transient=True,
) as progress,
):
- max_token = min(2000 - int(session.input_pos.max()), 1500)
+ max_token = min(1500 - int(session.input_pos.max()), 1000) * session.bsz
task = progress.add_task("T2S Decoding", total=max_token)
- for idx in range(1500):
- progress.update(task, advance=1)
+ for idx in range(max_token):
+ progress.update(task, advance=session.bsz)
if idx == 0:
session.kv_cache = decoder.init_cache(session.bsz)
- xy_dec = decoder.h.prefill(
- session.xy_pos,
- session.attn_mask,
- session.kv_cache,
- ) # bs, seq_len, embed_dim
- xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
+ t1 = time.perf_counter()
+ with timer("MLX.Prefill", debug=debug):
+ xy_dec = decoder.h.prefill(
+ session.xy_pos,
+ session.attn_mask,
+ session.kv_cache,
+ ) # bs, seq_len, embed_dim
+ xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
+ if debug:
+ mx.eval(xy_dec)
else:
args, kwds = decoder.pre_forward(session)
- xy_dec = decoder.h(
- session.input_pos,
- session.xy_pos,
- session.kv_cache,
- batch_idx,
- *args,
- **kwds,
- )
+
+ if debug:
+ mx.eval(session.input_pos, session.xy_pos, session.kv_cache, args, kwds, batch_idx)
+
+ if debug and idx == 50 and os.environ.get("MTL_CAPTURE_ENABLED") == "1":
+ os.makedirs("./profiler/mlx", exist_ok=True)
+ mx.metal.start_capture(f"./profiler/mlx/{time.time()}.gputrace")
+
+ with timer("MLX.Decode", debug=debug):
+ xy_dec = decoder.h(
+ session.input_pos,
+ session.xy_pos,
+ session.kv_cache,
+ batch_idx,
+ *args,
+ **kwds,
+ )
+ if debug:
+ mx.eval(xy_dec)
+
+ if debug and idx == 50 and os.environ.get("MTL_CAPTURE_ENABLED") == "1":
+ mx.metal.stop_capture()
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
@@ -106,28 +126,38 @@ class T2SEngine(T2SEngineProtocol):
if idx == 0:
logits[:, -1] = -mx.inf
- samples = session.sample(
- logits=logits,
- previous_tokens=session.y[:, : session.y_len + idx],
- top_k=request.top_k,
- top_p=request.top_p,
- repetition_penalty=request.repetition_penalty,
- temperature=request.temperature,
- )
+ with timer("MLX.Sampling", debug=debug):
+ samples = session.sample(
+ logits=logits,
+ previous_tokens=session.y[:, : session.y_len + idx],
+ top_k=request.top_k,
+ top_p=request.top_p,
+ repetition_penalty=request.repetition_penalty,
+ temperature=request.temperature,
+ )
- session.y[batch_idx, session.y_len + idx] = samples
+ session.y[batch_idx, session.y_len + idx] = samples
- argmax_token = mx.argmax(logits, axis=-1)
- sample_token = samples.squeeze(1)
- EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
+ if debug:
+ mx.eval(samples)
- newly_done_mask = EOS_mask & (~session.completed)
- newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
- pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
- pos_sorted = mx.sort(pos, axis=0)
- valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
- pos_final = pos_sorted[: int(valid_count)]
- newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
+ with timer("MLX.EOS", debug=debug):
+ mx.set_default_device(mx.Device(mx.cpu))
+ argmax_token = mx.argmax(logits, axis=-1)
+ sample_token = samples.squeeze(1)
+ EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
+
+ newly_done_mask = EOS_mask & (~session.completed)
+ newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
+ pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
+ pos_sorted = mx.sort(pos, axis=0)
+ valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
+ pos_final = pos_sorted[: int(valid_count)]
+ newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
+ mx.set_default_device(self.device)
+
+ if debug:
+ mx.eval(newly_done_indices)
if newly_done_indices.size > 0:
for i in newly_done_indices:
@@ -135,54 +165,53 @@ class T2SEngine(T2SEngineProtocol):
session.completed[newly_done_indices] = True
if mx.all(session.completed).item():
- if session.y[:, session.y_len :].sum() == 0:
- session.y_results = [mx.array([0]) for _ in range(session.bsz)]
- logger.error("Bad Zero Prediction")
- else:
- logger.info(
- f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}"
- )
- logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
- infer_time = time.perf_counter() - t1
- infer_speed = (idx - 1) / infer_time
+ logger.info(
+ f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}"
+ )
+ logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
+ infer_time = time.perf_counter() - t1
+ infer_speed = (idx + 1) / infer_time
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
for j in range(session.bsz):
if not session.completed[j].item():
- session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
+ session.y_results[j] = session.y[[j], session.y_len : session.y_len + idx]
session.completed[j] = True
logger.error("Bad Full Prediction")
- logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+ logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
- infer_speed = (idx - 1) / infer_time
+ infer_speed = (idx + 1) / infer_time
break
- y_emb = decoder.ar_audio_embedding(samples)
- session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
- mx.eval(session.xy_pos, session.y)
+ with timer("MLX.NextPos", debug=debug):
+ y_emb = decoder.ar_audio_embedding(samples)
+ session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
+ mx.eval(session.xy_pos, session.y)
- if idx == 1:
- t1 = time.perf_counter()
-
- if idx % 100 == 0:
+ if idx % 128 == 0:
mx.clear_cache()
- match session.device:
- case mx.gpu:
- mx.clear_cache()
- case mx.cpu:
- gc.collect()
-
result_mlx = session.y_results[: request.valid_length]
mx.eval(result_mlx)
result = [torch.tensor(k) for k in result_mlx]
- return result, infer_speed, infer_time
+ mx.clear_cache()
+
+ if debug:
+ timer.summary()
+ timer.clear()
+
+ return result, infer_speed, infer_time, (idx + 1) * session.bsz
def generate(self, request: T2SRequest):
try:
- result, infer_speed, infer_time = self._handle_request(request)
- t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
+ result, infer_speed, infer_time, total_tokens = self._handle_request(request)
+ t2s_result = T2SResult(
+ result=result,
+ infer_speed=(infer_speed, infer_time),
+ total_tokens=total_tokens,
+ status="Success",
+ )
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
return t2s_result
@@ -199,40 +228,39 @@ class T2SEngine(T2SEngineProtocol):
.replace("norm1", "attention_norm")
.replace("norm2", "ffn_norm")
)
- value_mlx = mx.array(value) # type: ignore
+ value_mlx = mx.array(value.to(torch.float32).cpu().numpy())
state_dict_mlx.append((key, value_mlx))
return state_dict_mlx
@staticmethod
- def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
+ def load_decoder(
+ weights_path: os.PathLike,
+ max_batch_size: int = 1,
+ backend: str = "MLX-Varlen",
+ quantize_mode: Literal["Affine", "MXFP4"] | None = None,
+ ) -> T2SDecoderABC:
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
- dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
+ dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=True, mmap=True)
config = dict_s1["config"]
match backend:
case "MLX-Varlen":
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
case "MLX-Static":
decoder_cls = mlx_static.T2SDecoder
- case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
- decoder_cls = mlx_quantized.T2SDecoder
case _:
raise RuntimeError(f"Backend {backend} Not Found")
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
state_dict = dict_s1["weight"]
state_dict_mlx = T2SEngine.replace_key(state_dict)
+
decoder.load_weights(state_dict_mlx)
- decoder.eval()
+
+ if quantize_mode is not None:
+ decoder.quantize(quantize_mode)
+ logger.info(
+ f"Quantized to {decoder.bits}-Bit with Group Size {decoder.group_size} by {quantize_mode} Quantization"
+ )
+
mx.eval(decoder)
-
- if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
- if backend == "MLX-Quantized-Affine":
- decoder.set_mode("affine")
- elif backend == "MLX-Quantized-MXFP4":
- decoder.set_mode("mxfp4")
- else:
- raise RuntimeError(f"Quantized Backend {backend} Not Supported")
- decoder.quantized()
- mx.eval(decoder)
-
return decoder
diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py b/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
index 0bff4219..80de8f39 100644
--- a/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
+++ b/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
@@ -2,12 +2,13 @@ from __future__ import annotations
import math
from abc import ABC, abstractmethod
-from typing import MutableSequence
+from typing import Literal, MutableSequence, Type
import mlx.core as mx
import mlx.nn as nn
+from mlx.core import Dtype
-from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
+from .structs_mlx import KVCache, KVCacheProtocol, T2SDecoderProtocol, T2SSessionMLX
Array = mx.array
@@ -43,26 +44,28 @@ class SinePositionalEmbedding(nn.Module):
embedding_dim: int,
scale: bool = False,
max_batch_size: int = 10,
- max_seq_len: int = 2000,
+ max_seq_length: int = 1500,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = mx.ones(1)
self.max_batch_size = max_batch_size
- self.max_seq_len = max_seq_len
+ self.max_seq_length = max_seq_length
self.reverse = False
- self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
- self.compute_pe()
+ self.pe: Array | None = None
- def compute_pe(self):
- """Reset the positional encodings."""
+ def compute_pe(self, dtype: Dtype):
+ """Compute the positional encodings."""
+
+ if self.pe is not None and self.pe.dtype == dtype:
+ return
if self.reverse:
- position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
+ position = mx.expand_dims(mx.arange(self.max_seq_length - 1, -1, -1.0), axis=1)
else:
- position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
+ position = mx.expand_dims(mx.arange(self.max_seq_length), axis=1)
div_term = mx.exp(
mx.arange(
0,
@@ -70,10 +73,13 @@ class SinePositionalEmbedding(nn.Module):
2,
)
* -(math.log(10000.0) / self.embedding_dim)
- )
- pe = self._pe
- pe[:, :, 0::2] = mx.sin(position * div_term)
- pe[:, :, 1::2] = mx.cos(position * div_term)
+ ).astype(dtype)
+ pe = mx.zeros((self.max_batch_size, self.max_seq_length, self.embedding_dim)).astype(dtype)
+
+ pe[:, :, 0::2] = mx.sin(position * div_term).astype(dtype)
+ pe[:, :, 1::2] = mx.cos(position * div_term).astype(dtype)
+
+ self.pe = pe
def __call__(self, input_pos: Array, x: Array):
"""
@@ -84,9 +90,11 @@ class SinePositionalEmbedding(nn.Module):
Returns:
embedded_x (Array): [batch_size, 1, embed_dim]
"""
+ self.compute_pe(x.dtype)
+ assert self.pe is not None
batch_size = x.shape[0]
- pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
+ pe_values = self.pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
@@ -98,7 +106,10 @@ class SinePositionalEmbedding(nn.Module):
Returns:
embedded_x (Array): [batch_size, seq_len, embed_dim]
"""
- pe_values = self._pe[:, : x.shape[-2]]
+ self.compute_pe(x.dtype)
+ assert self.pe is not None
+
+ pe_values = self.pe[:, : x.shape[-2]]
return x * self.x_scale + self.alpha * pe_values
@@ -125,12 +136,12 @@ class KVCacheHND(KVCacheProtocol):
@staticmethod
def prefill_kv(k_val, v_val, kv_cache):
- # k_val: [B, S, H, D]
+ # k_val: [B, H, S, D]
assert len(kv_cache) == 2
k_cache, v_cache = kv_cache
- k_cache[..., : k_val.shape[1], :] = k_val.swapaxes(1, 2)
- v_cache[..., : v_val.shape[1], :] = v_val.swapaxes(1, 2)
+ k_cache[..., : k_val.shape[2], :] = k_val
+ v_cache[..., : v_val.shape[2], :] = v_val
@staticmethod
def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
@@ -139,118 +150,6 @@ class KVCacheHND(KVCacheProtocol):
return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
-class KVCacheHNDQuantized(KVCacheProtocol):
- @staticmethod
- def _el_per_int(bits: int) -> int:
- return 32 // bits
-
- @staticmethod
- def _packed_dim(head_dim: int, bits: int = 8) -> int:
- el_per_int = KVCacheHNDQuantized._el_per_int(bits)
- if head_dim % el_per_int != 0:
- raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
- return head_dim // el_per_int
-
- @staticmethod
- def _group_count(head_dim: int, group_size: int = 32) -> int:
- assert group_size in {32, 64, 128}
- if head_dim % group_size != 0:
- raise ValueError(f"{head_dim} is not divisible by {group_size=}")
- return head_dim // group_size
-
- @staticmethod
- def empty(kv_cache) -> None:
- assert len(kv_cache) == 3
- (k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
-
- k_q[:] = 0
- k_s[:] = 0
- k_b[:] = 0
- v_q[:] = 0
- v_s[:] = 0
- v_b[:] = 0
-
- @staticmethod
- def update_cache(
- input_pos,
- k_val,
- v_val,
- kv_cache,
- cache_idx,
- ):
- # input_pos: [B, ], k_val: [B, H, 1, D]
-
- assert len(kv_cache) == 3
- (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
-
- k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
- v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
-
- ip0 = input_pos - 1
-
- k_q_out[cache_idx, :, ip0, None] = k_q
- k_s_out[cache_idx, :, ip0, None] = k_s
- k_b_out[cache_idx, :, ip0, None] = k_b
-
- v_q_out[cache_idx, :, ip0, None] = v_q
- v_s_out[cache_idx, :, ip0, None] = v_s
- v_b_out[cache_idx, :, ip0, None] = v_b
-
- return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
-
- @staticmethod
- def prefill_kv(
- k_val,
- v_val,
- kv_cache,
- ) -> None:
- assert len(kv_cache) == 3
- (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
-
- S = k_val.shape[1]
-
- k_sw = k_val.swapaxes(1, 2)
- v_sw = v_val.swapaxes(1, 2)
-
- k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
- v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
-
- k_q_out[..., :S, :] = k_q
- k_s_out[..., :S, :] = k_s
- k_b_out[..., :S, :] = k_b
-
- v_q_out[..., :S, :] = v_q
- v_s_out[..., :S, :] = v_s
- v_b_out[..., :S, :] = v_b
-
- @staticmethod
- def init_cache(
- batch_size: int,
- max_seq_length: int,
- n_heads: int,
- head_dim: int,
- dtype: mx.Dtype,
- *,
- group_size: int = 32,
- bits: int = 8,
- ) -> KVCacheQ:
- packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
- group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
-
- packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
- group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
-
- k_q = mx.zeros(packed_shape, dtype=mx.uint32)
- k_s = mx.zeros(group_shape, dtype=dtype)
- k_b = mx.zeros(group_shape, dtype=dtype)
-
- v_q = mx.zeros(packed_shape, dtype=mx.uint32)
- v_s = mx.zeros(group_shape, dtype=dtype)
- v_b = mx.zeros(group_shape, dtype=dtype)
-
- return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
-
-
class AttentionABC(ABC, nn.Module):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
super().__init__()
@@ -271,26 +170,26 @@ class AttentionABC(ABC, nn.Module):
self.kc_class: KVCacheProtocol
@abstractmethod
- def __call__(
- self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
- ) -> Array: ...
+ def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array) -> Array: ...
- def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
+ def prefill(self, x: Array, kv_cache: KVCache, attn_mask: Array):
bsz, seqlen, _ = x.shape
- q, k, v = self.in_proj(x).split(3, axis=-1)
+ qkv = self.in_proj(x)
- q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+ q, k, v = mx.split(qkv, 3, -1)
+
+ q = q.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
+ k = k.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
+ v = v.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
self.kc_class.prefill_kv(k, v, kv_cache)
- q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
-
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
attn = mx.nan_to_num(attn)
- attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
+ attn = attn.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1)
output = self.out_proj(attn)
@@ -321,7 +220,7 @@ class TransformerBlockABC(nn.Module):
self.attention_norm = nn.LayerNorm(self.hidden_dim)
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
- def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+ def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array):
h = self.attention_norm(
x
+ self.attention(
@@ -335,7 +234,7 @@ class TransformerBlockABC(nn.Module):
out = self.ffn_norm(h + self.feed_forward(h))
return out
- def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
+ def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache):
h = self.attention_norm(
x
+ self.attention.prefill(
@@ -382,7 +281,7 @@ class TransformerDecoderABC(nn.Module):
self,
input_pos: Array,
x: Array,
- kv_caches: MutableSequence[KVCache | KVCacheQ],
+ kv_caches: MutableSequence[KVCache],
cache_idx: Array,
*args,
**kwds,
@@ -399,7 +298,7 @@ class TransformerDecoderABC(nn.Module):
return x
- def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
+ def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache]):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(
x,
@@ -413,7 +312,7 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
def __init__(
self,
config: dict,
- max_seq_length: int = 2000,
+ max_seq_length: int = 1500,
max_batch_size: int = 10,
) -> None:
super().__init__()
@@ -451,24 +350,27 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
self.embedding_dim,
scale=False,
max_batch_size=max_batch_size,
- max_seq_len=max_seq_length,
+ max_seq_length=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
max_batch_size=max_batch_size,
- max_seq_len=max_seq_length,
+ max_seq_length=max_seq_length,
)
- self.kv_class: KVCacheProtocol
+ self.kv_class: Type[KVCacheProtocol]
- def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
+ self.bits: int = -1
+ self.group_size: int = -1
+
+ def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
- cache: MutableSequence[KVCache | KVCacheQ] = [
+ cache: MutableSequence[KVCache] = [
self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
for _ in range(self.n_layer)
]
@@ -503,8 +405,7 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
return xy_pos
def compile(self):
- setattr(self.h, "__call__", mx.compile(self.h.__call__))
- # setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
+ setattr(self.h, "__call__", mx.compile(self.h.__call__, shapeless=True))
def pre_forward(self, session: T2SSessionMLX):
attn_mask = session.attn_mask
@@ -525,4 +426,21 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
- mx.eval(attn_mask)
+
+    def quantize(self, mode: Literal["Affine", "MXFP4"] | None = None) -> None:
+        """Quantize ``self.h`` in place with ``nn.quantize``; ``None`` is a no-op.
+
+        "Affine" uses 8-bit affine and "MXFP4" uses 4-bit mxfp4, both with group
+        size 32; ``bits``/``group_size`` are stored on the decoder beforehand.
+
+        Raises:
+            ValueError: If ``mode`` is not a supported name.
+        """
+        if mode is None:
+            return
+        # One table lookup replaces the pre-check + ``match`` whose ``case _`` arm was unreachable.
+        settings = {"Affine": (8, 32, "affine"), "MXFP4": (4, 32, "mxfp4")}.get(mode)
+        if settings is None:
+            raise ValueError(f"Unsupported quantization mode: {mode}")
+        self.bits, self.group_size, mlx_mode = settings
+        nn.quantize(self.h, group_size=self.group_size, bits=self.bits, mode=mlx_mode)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/__init__.py b/GPT_SoVITS/Accelerate/PyTorch/__init__.py
index 91617fc6..7cef76e7 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/__init__.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/__init__.py
@@ -7,10 +7,23 @@ from .structs import T2SRequest, T2SResult
from .t2s_engine import T2SEngine as T2SEngineTorch
torch.set_grad_enabled(False)
+
+if torch.__version__ >= "2.9.0": # NOTE(review): assumes torch.__version__ compares as a version rather than as a plain string (as strings, "2.10" would sort below "2.9") — confirm
+ torch.backends.fp32_precision = "tf32" # type: ignore
+else:
+ torch.backends.cuda.matmul.allow_tf32 = True # fallback TF32 switches for releases below the threshold above
+ torch.backends.cudnn.allow_tf32 = True
+
+if torch.cuda.is_available():
+ torch.backends.cuda.preferred_blas_library("cublaslt")
+
+
+torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
+torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+torch.backends.cuda.matmul.allow_fp16_accumulation = True
+torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
backends = ["torch_varlen"]
if torch.cuda.is_available():
@@ -30,5 +43,21 @@ if torch.cuda.is_available():
# if torch.mps.is_available():
# backends.append("mps_flash_attn_varlen")
+BLACKWELL = False # flipped to True below when any visible CUDA device reports capability >= 9.0
+if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ major, minor = torch.cuda.get_device_capability(i)
+ sm_version = major + minor / 10.0 # e.g. capability (8, 9) becomes 8.9
+ if sm_version >= 9.0: # NOTE(review): 9.0 is presumably also reported by pre-Blackwell parts, so the flag name may be narrower than this check — confirm
+ BLACKWELL = True
-__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
+quantization_methods_torch: list[str | None] = [None] # None is always offered; T2SEngine.load_decoder skips quantization for it
+if importlib.util.find_spec("torchao") is not None: # these two modes are only listed when torchao is importable
+ quantization_methods_torch.append("Int8")
+ if BLACKWELL:
+ quantization_methods_torch.append("FP8")
+if BLACKWELL: # listed regardless of torchao; only the capability flag is checked
+ quantization_methods_torch.append("FP8_E4M3FN")
+
+
+__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends", "quantization_methods_torch"]
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
index 6b699198..adf7fceb 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
@@ -100,7 +100,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
- max_seq_length=2000,
+ max_seq_length=1500,
max_batch_size=10,
) -> None:
assert torch.cuda.is_available()
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py b/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
index f8b5d0a1..981a71df 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
@@ -78,7 +78,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
- max_seq_length=2000,
+ max_seq_length=1500,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
index ddd150c3..cd98739e 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
@@ -94,7 +94,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
- max_seq_length=2000,
+ max_seq_length=1500,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
index 961bc19a..37cae95d 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
@@ -78,7 +78,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
- max_seq_length=2000,
+ max_seq_length=1500,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
index 3618376e..cd16c79d 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
@@ -86,7 +86,7 @@ class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
- max_seq_length=2000,
+ max_seq_length=1500,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/quantization.py b/GPT_SoVITS/Accelerate/PyTorch/quantization.py
new file mode 100644
index 00000000..aff25cb5
--- /dev/null
+++ b/GPT_SoVITS/Accelerate/PyTorch/quantization.py
@@ -0,0 +1,158 @@
+from typing import cast
+
+import torch
+
+from . import nn
+
+Tensor = torch.Tensor
+
+
+# based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+def fp8_linear_forward(cls: nn.Linear, input: Tensor):
+    """Run ``cls`` as an FP8 matmul through ``torch._scaled_mm``.
+
+    Args:
+        cls: Linear layer whose weight is float8_e4m3fn or float8_e5m2; an optional
+            ``scale_weight`` attribute supplies the weight scale (defaults to 1).
+        input: 3-D activations. NOTE: clamped **in place** to [-448, 448].
+    Returns:
+        Tensor in ``input``'s dtype, shaped (-1, input.shape[1], cls.weight.shape[0]).
+    Raises:
+        RuntimeError: If the weight is not FP8 or ``input`` is not 3-D.
+    """
+    base_dtype = input.dtype
+    # Explicit errors replace bare ``raise``, which only surfaced as "No active exception to reraise".
+    if cls.weight.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise RuntimeError(f"fp8_linear_forward needs an FP8 weight, got {cls.weight.dtype}")
+    if input.dim() != 3:
+        raise RuntimeError(f"fp8_linear_forward needs a 3-D input, got shape {tuple(input.shape)}")
+    scale_weight: Tensor | None = getattr(cls, "scale_weight", None)
+    if scale_weight is None:
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+    else:
+        scale_weight = scale_weight.to(input.device).squeeze()
+    scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+    input = torch.clamp(input, min=-448, max=448, out=input)
+    # always e4m3fn because e5m2 * e5m2 is not supported
+    inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn).contiguous()
+    o = torch._scaled_mm(
+        inn, cls.weight.t(), out_dtype=base_dtype, bias=cls.bias, scale_a=scale_input, scale_b=scale_weight
+    )
+    return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+
+
+def convert_fp8_linear(
+    module: nn.Module,
+):
+    """Route every ``nn.Linear`` under ``module`` through ``fp8_linear_forward``; returns ``module``."""
+    for _, sub in list(module.named_modules()):
+        if not isinstance(sub, nn.Linear) or getattr(sub, "_fp8", False):
+            continue  # not a linear layer, or already patched by an earlier call
+        # A function stored on an instance is not bound to it, so the layer is captured
+        # explicitly; the bare function would be called as ``forward(input)`` and fail.
+        setattr(sub, "forward", lambda input, _layer=sub: fp8_linear_forward(_layer, input))
+        setattr(sub, "_fp8", True)
+    return module
+
+
+def per_tensor_quantize(tensor: torch.Tensor) -> tuple[Tensor, Tensor]:
+    """Quantize a tensor to float8_e4m3fn with one scale for the whole tensor.
+
+    Args:
+        tensor: The input tensor.
+
+    Returns:
+        ``(qweight, inv_scale)``: the float8_e4m3fn data and the float32
+        inverse scale, both of which ``torch._scaled_mm`` takes as inputs.
+    """
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    if tensor.numel() == 0:
+        # Empty tensors have no min/max; use a placeholder range of [0, 1]
+        # allocated on the input's device so the returned scale lives there too
+        # (the previous placeholders were created on the default device).
+        min_val = torch.zeros((), dtype=tensor.dtype, device=tensor.device)
+        max_val = torch.ones((), dtype=tensor.dtype, device=tensor.device)
+    else:
+        # aminmax avoids materialising ``tensor.abs()`` just to find the absmax.
+        min_val, max_val = tensor.aminmax()
+    amax = min_val.abs().max(max_val.abs())
+    # Scale is dtype max over absmax; the clamp is meant to avoid dividing by zero.
+    scale = finfo.max / amax.clamp(min=1e-12)
+    # The default float8 cast does not saturate, so clamp into range first.
+    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
+    qweight = qweight.to(torch.float8_e4m3fn)
+    scale = scale.float().reciprocal()
+    return qweight, scale
+
+
+def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
+    """Return ``A @ B.T (+ bias)`` for FP8 ``A``/``B`` with dequantization scales ``A_scale``/``B_scale``.
+
+    On compute capability >= 9.0 this calls ``torch._scaled_mm``; elsewhere it
+    dequantizes to ``out_dtype`` and uses a regular linear.
+    """
+    if torch.cuda.get_device_capability() >= (9, 0):
+        # _scaled_mm only accepts matrices, so fold any leading dims of A first.
+        out = torch._scaled_mm(
+            A.reshape(-1, A.shape[-1]), B.t(), out_dtype=out_dtype, scale_a=A_scale, scale_b=B_scale, bias=bias
+        )
+        # Older torch returned (output, amax); newer releases return the tensor alone.
+        out = out[0] if isinstance(out, tuple) else out
+        return out.reshape(*A.shape[:-1], B.shape[0])
+    # Fallback: dequantize both operands and run a regular linear.
+    return torch.nn.functional.linear(
+        A.to(out_dtype) * A_scale, B.to(out_dtype) * B_scale.to(out_dtype), bias=bias
+    )
+
+
+class FP8DynamicLinear(nn.Module):
+    """Linear layer holding a pre-quantized FP8 weight; activations are quantized on every call."""
+
+    def __init__(self, qweight: Tensor, scale: Tensor, bias: Tensor | None):
+        super().__init__()
+        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
+        # Registered as a buffer (``None`` allowed) so ``.to()`` moves the bias with the module.
+        self.register_buffer("bias", bias)
+
+    def forward(self, x):
+        """Quantize ``x`` per tensor, then multiply by the stored FP8 weight."""
+        # Defined as ``forward`` rather than ``__call__`` so nn.Module's call machinery (hooks) still runs.
+        qinput, x_scale = per_tensor_quantize(x)
+        return fp8_gemm(
+            A=qinput, A_scale=x_scale, B=self.weight, B_scale=self.weight_scale,
+            bias=self.bias, out_dtype=x.dtype,
+        )
+
+
+def replace_all_linear_with_fp8(model: nn.Module):
+    """
+    Recursively replace every nn.Linear with FP8DynamicLinear in-place.
+
+    Weights are quantized once with ``per_tensor_quantize``; each replacement is
+    moved to the device of the layer it replaces. Layers without a bias are
+    supported (previously ``child.bias.clone()`` failed on ``None``).
+    """
+
+    def _recursively_replace(parent: nn.Module):
+        # Snapshot the children: the loop body rebinds attributes on ``parent``.
+        for child_name, child in list(parent.named_children()):
+            child = cast(nn.Module, child)
+            if isinstance(child, FP8DynamicLinear):
+                continue  # already converted
+
+            if not isinstance(child, nn.Linear):
+                _recursively_replace(child)
+                continue
+
+            # Copy the bias only when the layer actually has one.
+            b = None if child.bias is None else child.bias.clone()
+            qw, qs = per_tensor_quantize(child.weight)
+
+            quant_linear = FP8DynamicLinear(qw, qs, b)
+            quant_linear.to(child.weight.device)
+
+            # Swap the module in on the parent under the same attribute name.
+            setattr(parent, child_name, quant_linear)
+
+    _recursively_replace(model)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py b/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
index 0b9eec0c..06bdfa7c 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
@@ -1,20 +1,79 @@
-from typing import Protocol
+from typing import Callable, Protocol, TypeVar, cast
import torch
import torch.nn.functional as F
+from typing_extensions import ParamSpec
+P = ParamSpec("P")
+R = TypeVar("R")
Tensor = torch.Tensor
+def script(fn: Callable[P, R]) -> Callable[P, R]:
+    """Compile ``fn`` with ``torch.jit.script`` while keeping its static signature for type checkers."""
+    return cast(Callable[P, R], torch.jit.script(fn))
+
+
+@script
+def apply_repetition_penalty(logits: Tensor, previous_tokens: Tensor, repetition_penalty: float): # penalise logits at the positions listed in ``previous_tokens``; mutates and returns ``logits``
+ previous_tokens = previous_tokens.long() # cast ids to int64 before using them as gather/scatter indices
+ score = torch.gather(logits, dim=1, index=previous_tokens)
+ score = torch.where(
+ score < 0,
+ score * repetition_penalty, # negative scores are multiplied by the penalty
+ score / repetition_penalty, # all other scores are divided by it
+ )
+ logits.scatter_(dim=1, index=previous_tokens, src=score) # in-place write-back into ``logits``
+ return logits
+
+
+@script
+def apply_greedy_sampling(logits: Tensor): # deterministic pick: index of the largest logit along the last dim
+ return torch.argmax(logits, dim=-1, keepdim=True).to(dtype=torch.int32) # keepdim retains the reduced axis; result is cast to int32
+
+
+@script
+def apply_temperature(logits: Tensor, temperature: float): # returns a new tensor; ``logits`` itself is not modified
+ return logits / temperature # no zero guard here; sample_naive returns the greedy pick for temperature <= 1e-5 before calling this
+
+
+@script
+def apply_top_k(logits: Tensor, top_k: int):
+    # Keep the ``top_k`` largest logits per row, rest become -inf; k is capped at the last-dim size so topk cannot be asked for more entries than exist.
+    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+    pivot = v[:, -1].unsqueeze(-1)
+    return torch.where(logits < pivot, -float("Inf"), logits)
+
+
+@script
+def apply_top_p(logits: Tensor, top_p: float): # mask (with -inf) the logits whose sorted cumulative probability exceeds ``top_p``
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+ cum_probs[cum_probs > 1] = 1 # cap cumulative sums that exceed 1
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) # map the mask from sorted order back to original positions
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+ return logits
+
+
+@script
+def apply_sampling(logits: Tensor): # pick one int32 index per row from noise-perturbed softmax probabilities
+ probs = F.softmax(logits, dim=-1)
+ q = -torch.log(torch.rand_like(probs)) # -log(U) noise with U drawn by rand_like
+ idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32) # NOTE(review): looks like the exponential-noise argmax trick used in place of torch.multinomial — confirm
+ return idx_next
+
+
class SampleProtocol(Protocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
+ repetition_penalty: float,
temperature: float,
top_k: int,
top_p: float,
- repetition_penalty: float,
) -> Tensor: ...
@@ -23,45 +82,23 @@ class sample_naive(SampleProtocol):
def __call__(
logits: Tensor,
previous_tokens: Tensor,
- temperature: float,
- top_k: int,
- top_p: float,
- repetition_penalty: float,
+ repetition_penalty: float = 1.35,
+ temperature: float = 1.0,
+ top_k: int = 15,
+ top_p: float = 1.0,
):
- if temperature <= 1e-5:
- probs = F.softmax(logits, dim=-1)
- return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)
-
if repetition_penalty != 1.0:
- previous_tokens = previous_tokens.long()
- score = torch.gather(logits, dim=1, index=previous_tokens)
- score = torch.where(
- score < 0,
- score * repetition_penalty,
- score / repetition_penalty,
- )
- logits.scatter_(dim=1, index=previous_tokens, src=score)
+ logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty)
+
+ if temperature <= 1e-5:
+ return apply_greedy_sampling(logits)
+ elif temperature < 1.0:
+ logits = apply_temperature(logits, temperature)
+
+ if top_k < 1025:
+ logits = apply_top_k(logits, top_k)
if top_p < 1.0:
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
- cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
- cum_probs[cum_probs > 1] = 1
- sorted_indices_to_remove = cum_probs > top_p
- sorted_indices_to_remove[:, 0] = False # keep at least one option
- indices_to_remove = sorted_indices_to_remove.scatter(
- dim=1, index=sorted_indices, src=sorted_indices_to_remove
- )
- logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+ logits = apply_top_p(logits, top_p)
- if temperature < 1.0:
- logits /= temperature
-
- v, _ = torch.topk(logits, top_k)
- pivot = v[:, -1].unsqueeze(-1)
- logits = torch.where(logits < pivot, -float("Inf"), logits)
-
- probs = F.softmax(logits, dim=-1)
- q = -torch.log(torch.rand_like(probs))
- idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
-
- return idx_next
+ return apply_sampling(logits)
diff --git a/GPT_SoVITS/Accelerate/PyTorch/structs.py b/GPT_SoVITS/Accelerate/PyTorch/structs.py
index 1822acdc..fe11a501 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/structs.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/structs.py
@@ -18,6 +18,7 @@ Tensor = torch.Tensor
class T2SResult:
result: list[Tensor] | None = None
infer_speed: tuple[float, float] = (0.0, 0.0)
+ total_tokens: int = 0
status: Literal["Success", "Error"] = "Success"
exception: Optional[Exception] = None
traceback: Optional[str] = None
@@ -66,7 +67,7 @@ class T2SDecoderProtocol(Protocol):
class T2SEngineProtocol(Protocol):
- def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ...
+ def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float, int]: ...
def generate(self, request: T2SRequest) -> T2SResult: ...
@@ -107,6 +108,7 @@ class T2SSession:
self.input_pos = torch.zeros_like(self.prefill_len)
self.input_pos.add_(self.prefill_len)
+ self.input_pos.squeeze_(0)
# CUDA Graph
self.stream: Optional[torch.cuda.Stream] = None
diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
index ada5b096..89f18bc8 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
@@ -1,15 +1,15 @@
import contextlib
-import gc
import os
import sys
import time
import traceback
from importlib import import_module
+from typing import Literal
import torch
from rich.progress import BarColumn, Progress, TextColumn
-from ..logger import SpeedColumnToken, console, logger
+from ..logger import SpeedColumnToken, console, logger, timer
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
from .t2s_model_abc import (
CUDAGraphCacheABC,
@@ -41,12 +41,14 @@ class T2SEngine(T2SEngineProtocol):
decoder = self.decoder_model
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
batch_idx = torch.arange(session.bsz)
+ debug = request.debug
t1 = 0.0
infer_speed = 0.0
infer_time = 0.0
+ idx = 0
- torch_profiler = TorchProfiler(request.debug)
+ torch_profiler = TorchProfiler(debug)
with (
torch_profiler.profiler(),
Progress(
@@ -59,14 +61,15 @@ class T2SEngine(T2SEngineProtocol):
) as progress,
):
torch_profiler.start()
- max_token = int(min(2000 - session.input_pos.max(), 1500))
+ max_token = min(int(1500 - session.input_pos.max()), 1000)
task = progress.add_task("T2S Decoding", total=max_token)
for idx in range(max_token):
progress.update(task, advance=1)
if idx == 0:
- with torch_profiler.record("Prefill"):
+ with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
session.kv_cache = decoder.init_cache(session.bsz)
+ t1 = time.perf_counter()
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
@@ -78,7 +81,7 @@ class T2SEngine(T2SEngineProtocol):
):
self.graphcache.assign_graph(session)
- with torch_profiler.record("Decode"):
+ with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
if session.graph:
assert session.stream
session.stream.wait_stream(torch.cuda.default_stream())
@@ -103,7 +106,7 @@ class T2SEngine(T2SEngineProtocol):
if idx == 0:
logits[:, -1] = float("-inf")
- with torch_profiler.record("Sampling"):
+ with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
samples = session.sample(
logits=logits,
previous_tokens=session.y[:, : session.y_len + idx],
@@ -115,7 +118,7 @@ class T2SEngine(T2SEngineProtocol):
session.y[batch_idx, session.y_len + idx] = samples
session.input_pos.add_(1)
- with torch_profiler.record("EOS"):
+ with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
argmax_token = torch.argmax(logits, dim=-1)
sample_token = samples.squeeze(1)
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
@@ -128,73 +131,48 @@ class T2SEngine(T2SEngineProtocol):
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
session.completed[newly_done_indices] = True
- if torch.all(session.completed).item():
- if session.y[:, session.y_len :].sum() == 0:
- session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
- logger.error("Bad Zero Prediction")
- else:
- logger.info(
- f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
- )
- logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
- infer_time = time.perf_counter() - t1
- infer_speed = (idx - 1) / infer_time
- break
+ if torch.all(session.completed).item():
+ logger.info(
+ f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
+ )
+ logger.info(
+ f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
+ )
+ infer_time = time.perf_counter() - t1
+ infer_speed = (idx + 1) * session.bsz / infer_time
+ break
- if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
- for i in range(session.bsz):
- if not session.completed[i].item():
- session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
- session.completed[i] = True
- logger.error("Bad Full Prediction")
- break
+ if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
+ for i in range(session.bsz):
+ if not session.completed[i].item():
+ session.y_results[i] = session.y[[i], session.y_len : session.y_len + idx]
+ session.completed[i] = True
+ logger.error("Bad Full Prediction")
+ infer_time = time.perf_counter() - t1
+ infer_speed = (idx + 1) * session.bsz / infer_time
+ break
- with torch_profiler.record("NextPos"):
+ with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
- if idx == 1:
- t1 = time.perf_counter()
-
- if idx == 20:
+ if idx == 10:
torch_profiler.end()
- if idx % 100 == 0:
- match session.device.type:
- case "cuda":
- torch.cuda.empty_cache()
- case "mps":
- torch.mps.empty_cache()
- case "xpu":
- torch.xpu.empty_cache()
- case "mtia":
- torch.mtia.empty_cache()
- case "cpu":
- pass
-
- match session.device.type:
- case "cuda":
- if session.stream is not None:
- torch.cuda.current_stream().wait_stream(session.stream)
- torch.cuda.empty_cache()
- case "mps":
- torch.mps.empty_cache()
- case "xpu":
- torch.xpu.empty_cache()
- case "mtia":
- torch.mtia.empty_cache()
- case "cpu":
- gc.collect(1)
-
if request.use_cuda_graph and self.graphcache.is_applicable:
self.graphcache.release_graph(session)
- return session.y_results[: request.valid_length], infer_speed, infer_time
+ return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz
def generate(self, request: T2SRequest):
try:
- result, infer_speed, infer_time = self._handle_request(request)
- t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
+ result, infer_speed, infer_time, total_tokens = self._handle_request(request)
+ t2s_result = T2SResult(
+ result=result,
+ infer_speed=(infer_speed, infer_time),
+ total_tokens=total_tokens,
+ status="Success",
+ )
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
if self.decoder_model.compiled:
@@ -203,7 +181,12 @@ class T2SEngine(T2SEngineProtocol):
return t2s_result
@staticmethod
- def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"):
+ def load_decoder(
+ weights_path: os.PathLike,
+ max_batch_size: int = 1,
+ backend: str = "Flash-Attn-Varlen-CUDAGraph",
+ quantize_mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None,
+ ) -> T2SDecoderABC:
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
decoder_cls_name = "T2SDecoder"
@@ -215,6 +198,10 @@ class T2SEngine(T2SEngineProtocol):
state_dict = dict_s1["weight"]
decoder.load_state_dict(state_dict)
+ if quantize_mode is not None:
+ decoder.quantize(quantize_mode)
+ logger.info(f"Quantized by {quantize_mode} Quantization")
+
return decoder.eval()
def init_cache(self):
diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
index 7b46b6ab..6c8b51f5 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
@@ -13,7 +13,7 @@ import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from pathlib import Path
-from typing import MutableSequence
+from typing import Literal, MutableSequence
import torch
import torch._inductor.config
@@ -24,6 +24,7 @@ from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_t
from tools.my_utils import get_machine_id
from . import nn
+from .quantization import replace_all_linear_with_fp8
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
Tensor = torch.Tensor
@@ -61,26 +62,26 @@ class SinePositionalEmbedding(nn.Module):
scale: bool = False,
alpha: bool = False,
max_batch_size: int = 10,
- max_seq_len: int = 2000,
+ max_seq_length: int = 1500,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.max_batch_size = max_batch_size
- self.max_seq_len = max_seq_len
+ self.max_seq_length = max_seq_length
self.reverse = False
- self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
+ self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_length, embedding_dim), persistent=False)
self.pe: torch.Tensor
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
- position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
+ position = torch.arange(self.max_seq_length - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
- position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
+ position = torch.arange(self.max_seq_length, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
@@ -423,7 +424,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
def __init__(
self,
config: dict,
- max_seq_length: int = 2000,
+ max_seq_length: int = 1500,
max_batch_size: int = 10,
) -> None:
super().__init__()
@@ -467,7 +468,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
scale=False,
alpha=True,
max_batch_size=max_batch_size,
- max_seq_len=max_seq_length,
+ max_seq_length=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
self.ar_audio_position = SinePositionalEmbedding(
@@ -475,9 +476,12 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
scale=False,
alpha=True,
max_batch_size=max_batch_size,
- max_seq_len=max_seq_length,
+ max_seq_length=max_seq_length,
)
+ self.bits: int
+ self.group_size: int
+
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
@@ -608,6 +612,32 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
def post_forward(self, idx: int, session: T2SSession) -> None:
return
+ # Quantize `self.h` according to `mode`; returns immediately when `mode` is None.
+ # "Int8" and "FP8" call torchao's `quantize_` on `self.h` with a weight-only config;
+ # "FP8_E4M3FN" calls the project helper `replace_all_linear_with_fp8(self.h)`.
+ # Raises ValueError for any other mode string.
+ def quantize(self, mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None) -> None:
+ if mode is None:
+ return
+ if mode not in {"Int8", "FP8", "FP8_E4M3FN"}:
+ raise ValueError(f"Unsupported quantization mode: {mode}")
+ match mode:
+ case "Int8":
+ self.bits = 8
+ self.group_size = 32
+ # torchao is imported inside the branch, so it is only loaded when this mode is selected
+ import torchao
+
+ torchao.quantization.quantize_(self.h, torchao.quantization.Int8WeightOnlyConfig(self.group_size))
+
+ case "FP8":
+ self.bits = 8
+ import torchao
+
+ torchao.quantization.quantize_(self.h, torchao.quantization.Float8WeightOnlyConfig())
+
+ case "FP8_E4M3FN":
+ self.bits = 8
+ replace_all_linear_with_fp8(self.h)
+
+ # NOTE(review): unreachable in practice, the membership check above already rejects other values
+ case _:
+ raise ValueError(f"Unsupported Quantization Mode for PyTorch: {mode}")
+
class CUDAGraphCacheABC(ABC):
def __init__(
@@ -662,13 +692,13 @@ class CUDAGraphCacheABC(ABC):
class TorchProfiler:
- def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
- self.debug = debug
- self.log_dir = log_dir + str(time.time())
+ def __init__(self, debug: bool, log_dir: str = "./profiler/torch") -> None:
+ self.debug = debug and os.environ.get("TORCH_PROFILER") == "1"
+ self.log_dir = log_dir + "/" + str(time.time())
self.__profiler: torch.profiler.profile
if self.debug and not os.path.exists(self.log_dir):
- os.makedirs(self.log_dir)
+ os.makedirs(self.log_dir, exist_ok=True)
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
diff --git a/GPT_SoVITS/Accelerate/__init__.py b/GPT_SoVITS/Accelerate/__init__.py
index 797fe1d0..5508f6d0 100644
--- a/GPT_SoVITS/Accelerate/__init__.py
+++ b/GPT_SoVITS/Accelerate/__init__.py
@@ -1,18 +1,13 @@
from . import MLX, PyTorch
from .logger import console, logger, tb
-from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
+from .MLX import quantization_methods_mlx
+from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult, quantization_methods_torch
from .PyTorch.structs import T2SEngineProtocol
backends = PyTorch.backends + MLX.backends
backends = [
- b.replace("_", "-")
- .title()
- .replace("Mlx", "MLX")
- .replace("Mps", "MPS")
- .replace("Cuda", "CUDA")
- .replace("Mxfp4", "MXFP4")
- for b in backends
+ b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends
]
@@ -27,4 +22,6 @@ __all__ = [
"console",
"tb",
"T2SEngineProtocol",
+ "quantization_methods_torch",
+ "quantization_methods_mlx",
]
diff --git a/GPT_SoVITS/Accelerate/logger.py b/GPT_SoVITS/Accelerate/logger.py
index 021e72e3..bde6f230 100644
--- a/GPT_SoVITS/Accelerate/logger.py
+++ b/GPT_SoVITS/Accelerate/logger.py
@@ -1,4 +1,7 @@
import sys
+import time
+from collections import defaultdict
+from contextlib import nullcontext
from typing import Optional
from loguru import logger
@@ -201,3 +204,46 @@ if __name__ == "__main__":
raise RuntimeError()
except Exception:
logger.bind(show_locals=False).exception("TEST")
+
+
+# Collects elapsed times per category; used as `with timer(category, debug=True): ...`.
+class Timer:
+ def __init__(self):
+ # category -> elapsed times in seconds, one entry per completed timed block
+ self.records: dict[str, list[float]] = defaultdict(list)
+ # (category, start timestamp in ns) for timed blocks that have been entered but not yet exited
+ self._stack: list[tuple[str, int]] = []
+
+ # Return a context manager that times its body under `category` when `debug` is true;
+ # otherwise return `nullcontext()`, which records nothing.
+ def __call__(self, category: str, debug=False):
+ # local alias so the nested class can reach this Timer (shadows the module-level `timer`)
+ timer = self
+
+ class _Ctx:
+ def __enter__(self):
+ timer._stack.append((category, time.perf_counter_ns()))
+ return timer # returned so timer methods can be called inside the with block
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ end = time.perf_counter_ns()
+ if not timer._stack:
+ raise RuntimeError("Timer stack underflow: __exit__ without matching __enter__")
+ cat, start = timer._stack.pop()
+ if cat != category:
+ raise RuntimeError(f"Mismatched timer context: expected '{cat}', got '{category}'")
+ # perf_counter_ns() is in nanoseconds; store seconds
+ elapsed_sec = (end - start) / 1e9
+ timer.records[cat].append(elapsed_sec)
+ # False: exceptions raised inside the with block are not suppressed
+ return False
+
+ if debug:
+ return _Ctx()
+ else:
+ return nullcontext()
+
+ # Drop all recorded timings and any open stack entries.
+ def clear(self):
+ self.records.clear()
+ self._stack.clear()
+
+ # Print one line per category with its sample count, total and average seconds.
+ def summary(self):
+ for cat, times in self.records.items():
+ total = sum(times)
+ avg = total / len(times) if times else 0.0
+ print(f"{cat}: count={len(times)}, total={total:.6f}s, avg={avg:.6f}s")
+
+
+timer = Timer()
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index b5150c3f..904affbb 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -51,13 +51,11 @@ warnings.filterwarnings(
)
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
-logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
-logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -88,6 +86,12 @@ def lang_type(text: str) -> str:
return "Auto"
+# Converter passed as argparse `type=`: the literal string "None" becomes Python None,
+# any other string is returned unchanged.
+def none_or_str(value: str):
+ if value == "None":
+ return None
+ return value
+
+
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="inference_webui",
@@ -108,6 +112,15 @@ def build_parser() -> argparse.ArgumentParser:
help="AR Inference Backend",
required=False,
)
+ p.add_argument(
+ "--quantization",
+ "-q",
+ default="None",
+ choices=MLX.quantization_methods_mlx + PyTorch.quantization_methods_torch,
+ type=none_or_str,
+ help="Quantization Method",
+ required=False,
+ )
p.add_argument(
"--device",
"-d",
@@ -393,10 +406,9 @@ with contextlib.suppress(UnboundLocalError):
def change_gpt_weights(gpt_path):
global t2s_engine, config
-
if "mlx" in ar_backend.lower():
t2s_engine = MLX.T2SEngineMLX(
- MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend),
+ MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
"mx.gpu" if infer_device.type != "cpu" else "mx.cpu",
dtype=dtype,
)
@@ -404,7 +416,7 @@ def change_gpt_weights(gpt_path):
total = sum((p[-1].size for p in mxutils.tree_flatten(t2s_engine.decoder_model.parameters()))) # type: ignore
else:
t2s_engine = PyTorch.T2SEngineTorch(
- PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend),
+ PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
device,
dtype=dtype,
)
@@ -824,7 +836,7 @@ def get_tts_wav(
pred_semantic_list = t2s_result.result
assert pred_semantic_list, t2s_result.traceback
pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device)
- infer_len.append(pred_semantic.shape[-1])
+ infer_len.append(t2s_result.total_tokens)
infer_time.append(t2s_result.infer_speed[-1])
cache[i_text] = pred_semantic
diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py
index 7ec03e77..ec1e0dd2 100644
--- a/GPT_SoVITS/text/chinese2.py
+++ b/GPT_SoVITS/text/chinese2.py
@@ -30,6 +30,7 @@ pinyin_to_symbol_map = {
parent_directory = os.path.dirname(current_file_path)
is_g2pw = os.getenv("G2PW", "1") == "1"
+debug = os.getenv("DEBUG", "0") == "1"
if is_g2pw:
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
@@ -202,7 +203,8 @@ def _g2p(segments):
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
- print("pypinyin结果", initials, finals)
+ if debug:
+ print("pypinyin结果", initials, finals)
else:
# g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
diff --git a/README.md b/README.md
index 33c442e9..558c1edf 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.
[](https://github.com/RVC-Boss/gpt-sovits/releases)
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
-[](https://lj1995-gpt-sovits-proplus.hf.space/)
+[](https://lj1995-gpt-sovits-proplus.hf.space/)
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
@@ -57,7 +57,7 @@ Unseen speakers few-shot fine-tuning demo:
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
-| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
+| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
diff --git a/docs/cn/README.md b/docs/cn/README.md
index 2f03b3ca..12e746d6 100644
--- a/docs/cn/README.md
+++ b/docs/cn/README.md
@@ -13,7 +13,7 @@
[](https://github.com/RVC-Boss/gpt-sovits/releases)
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
-[](https://lj1995-gpt-sovits-proplus.hf.space/)
+[](https://lj1995-gpt-sovits-proplus.hf.space/)
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
@@ -57,7 +57,7 @@
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
-| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
+| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
diff --git a/docs/ja/README.md b/docs/ja/README.md
index c41f6aed..4f064692 100644
--- a/docs/ja/README.md
+++ b/docs/ja/README.md
@@ -13,7 +13,7 @@
[](https://github.com/RVC-Boss/gpt-sovits/releases)
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
-[](https://lj1995-gpt-sovits-proplus.hf.space/)
+[](https://lj1995-gpt-sovits-proplus.hf.space/)
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
@@ -57,7 +57,7 @@
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
-| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
+| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
diff --git a/docs/ko/README.md b/docs/ko/README.md
index 28012977..62b09db1 100644
--- a/docs/ko/README.md
+++ b/docs/ko/README.md
@@ -13,7 +13,7 @@
[](https://github.com/RVC-Boss/gpt-sovits/releases)
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
-[](https://lj1995-gpt-sovits-proplus.hf.space/)
+[](https://lj1995-gpt-sovits-proplus.hf.space/)
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
@@ -57,7 +57,7 @@
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
-| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
+| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
diff --git a/docs/tr/README.md b/docs/tr/README.md
index de94d5bd..8bad35cb 100644
--- a/docs/tr/README.md
+++ b/docs/tr/README.md
@@ -13,7 +13,7 @@ Güçlü Birkaç Örnekli Ses Dönüştürme ve Metinden Konuşmaya Web Arayüz
[](https://github.com/RVC-Boss/gpt-sovits/releases)
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
-[](https://lj1995-gpt-sovits-proplus.hf.space/)
+[](https://lj1995-gpt-sovits-proplus.hf.space/)
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
@@ -57,7 +57,7 @@ Görünmeyen konuşmacılar birkaç örnekli ince ayar demosu:
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
-| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
+| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
diff --git a/install.ps1 b/install.ps1
index d993e97b..4646a60e 100644
--- a/install.ps1
+++ b/install.ps1
@@ -225,7 +225,7 @@ switch ($Device) {
Write-Warning "CUDA 12.8 Is Not Supported By Current Driver"
}
Write-Info "Installing PyTorch For CUDA 12.8..."
- Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
+ Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cu128"
Invoke-Conda cuda-nvcc=12.8
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
Write-Info "Installing Flash Attn..."
@@ -240,7 +240,7 @@ switch ($Device) {
Write-Warning "CUDA 12.6 Is Not Supported By Current Driver"
}
Write-Info "Installing PyTorch For CUDA 12.6..."
- Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
+ Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cu126"
Invoke-Conda cuda-nvcc=12.6
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
Write-Info "Installing Flash Attn..."
@@ -249,7 +249,7 @@ switch ($Device) {
}
"CPU" {
Write-Info "Installing PyTorch For CPU..."
- Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
+ Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cpu"
}
}
Write-Success "PyTorch Installed"
diff --git a/install.sh b/install.sh
index a5cce698..e7d298af 100644
--- a/install.sh
+++ b/install.sh
@@ -334,14 +334,14 @@ if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
echo -r "${WARNING}CUDA 12.8 Is Not Supported By Current Driver"
fi
echo -e "${INFO}Installing PyTorch For CUDA 12.8..."
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
+ run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cu128"
run_conda_quiet cuda-nvcc=12.8
elif [ "$CUDA" = 126 ]; then
if awk "BEGIN {exit !($CUDAVERSION < 12.6)}"; then
echo -r "${WARNING}CUDA 12.6 Is Not Supported By Current Driver"
fi
echo -e "${INFO}Installing PyTorch For CUDA 12.6..."
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
+ run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cu126"
run_conda_quiet cuda-nvcc=12.6
fi
echo -e "${INFO}Installing Flash Attn"
@@ -350,14 +350,14 @@ if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
echo -e "${SUCCESS}Flash Attn Installed"
elif [ "$USE_MLX" = true ] && [ "$WORKFLOW" = false ]; then
echo -e "${INFO}Installing MLX & PyTorch For MPS..."
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
+ run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cpu"
run_pip_quiet mlx
elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
echo -e "${INFO}Installing PyTorch For ROCm 6.2..."
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2"
+ run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/rocm6.2"
elif [ "$USE_CPU" = true ] && [ "$WORKFLOW" = false ]; then
echo -e "${INFO}Installing PyTorch For CPU..."
- run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
+ run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cpu"
elif [ "$WORKFLOW" = false ]; then
echo -e "${ERROR}Unknown Err"
exit 1
diff --git a/requirements.txt b/requirements.txt
index 0875ffa0..37ab0bce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ peft
py-cpuinfo
pypinyin
split-lang
+torchao
torchaudio
torchcodec
transformers
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..fd4b7127
--- /dev/null
+++ b/test.py
@@ -0,0 +1,879 @@
+import argparse
+import contextlib
+import gc
+import logging
+import os
+import re
+import traceback
+import warnings
+from functools import partial
+from pathlib import Path
+from time import perf_counter as ttime
+from typing import Any
+
+import gradio as gr
+import librosa
+import numpy as np
+import torch
+import torchaudio
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+from config import (
+ change_choices,
+ get_dtype,
+ get_weights_names,
+ pretrained_sovits_name,
+)
+from config import (
+ infer_device as default_device,
+)
+from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends
+from GPT_SoVITS.Accelerate.logger import console, timer
+from GPT_SoVITS.feature_extractor import cnhubert
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
+from GPT_SoVITS.process_ckpt import inspect_version
+from GPT_SoVITS.text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
+from GPT_SoVITS.text.LangSegmenter import LangSegmenter
+from tools.i18n.i18n import I18nAuto, scan_language_list
+from tools.my_utils import DictToAttrRecursive
+
+warnings.filterwarnings(
+ "ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
+)
+warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
+
+logging.getLogger("urllib3").setLevel(logging.ERROR)
+logging.getLogger("httpcore").setLevel(logging.ERROR)
+logging.getLogger("httpx").setLevel(logging.ERROR)
+logging.getLogger("asyncio").setLevel(logging.ERROR)
+logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
+logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+
+_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")
+
+
+# Converter passed as argparse `type=` for the positional language argument.
+# "Auto" is returned as-is. Otherwise `text` must match `_LANG_RE` (ll_CC or ll-CC);
+# it is normalised to the underscore form and returned when `scan_language_list()`
+# contains it, else "Auto" is returned.
+# Raises argparse.ArgumentTypeError when `text` does not match the expected format.
+def lang_type(text: str) -> str:
+ if text == "Auto":
+ return text
+ if not _LANG_RE.match(text):
+ raise argparse.ArgumentTypeError(f"Unsupported Format: {text}, Expected ll_CC/ll-CC")
+ ll, cc = re.split(r"[_-]", text)
+ language = f"{ll}_{cc}"
+ if language in scan_language_list():
+ return language
+ else:
+ return "Auto"
+
+
+# Converter passed as argparse `type=`: the literal string "None" becomes Python None,
+# any other string is returned unchanged.
+def none_or_str(value: str):
+ if value == "None":
+ return None
+ return value
+
+
+# Build the command-line parser: one optional positional `language` plus options for
+# the AR backend, quantization method, device, port, share link and model paths.
+# NOTE(review): `prog` and `description` still name `inference_webui` although this copy is in test.py.
+def build_parser() -> argparse.ArgumentParser:
+ p = argparse.ArgumentParser(
+ prog="inference_webui",
+ description=f"python -s -m GPT_SoVITS.inference_webui zh_CN -b {backends[-1]}",
+ )
+ p.add_argument(
+ "language",
+ nargs="?",
+ default="Auto",
+ type=lang_type,
+ help="Language Code, Such as zh_CN, en-US",
+ )
+ p.add_argument(
+ "--backends",
+ "-b",
+ choices=backends,
+ default=backends[-1],
+ help="AR Inference Backend",
+ required=False,
+ )
+ # The default is the string "None"; `none_or_str` is the converter that turns it into Python None.
+ p.add_argument(
+ "--quantization",
+ "-q",
+ default="None",
+ choices=MLX.quantization_methods_mlx + PyTorch.quantization_methods_torch,
+ type=none_or_str,
+ help="Quantization Method",
+ required=False,
+ )
+ p.add_argument(
+ "--device",
+ "-d",
+ default=str(default_device),
+ help="Inference Device",
+ required=False,
+ )
+ p.add_argument(
+ "--port",
+ "-p",
+ default=9872,
+ type=int,
+ help="WebUI Binding Port",
+ required=False,
+ )
+ p.add_argument(
+ "--share",
+ "-s",
+ default=False,
+ action="store_true",
+ help="Gradio Share Link",
+ required=False,
+ )
+ p.add_argument(
+ "--cnhubert",
+ default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
+ help="CNHuBERT Pretrain",
+ required=False,
+ )
+ p.add_argument(
+ "--bert",
+ default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+ help="BERT Pretrain",
+ required=False,
+ )
+ p.add_argument(
+ "--gpt",
+ default="",
+ help="GPT Model",
+ required=False,
+ )
+ p.add_argument(
+ "--sovits",
+ default="",
+ help="SoVITS Model",
+ required=False,
+ )
+
+ return p
+
+
+args = build_parser().parse_args()
+
+hps: Any = None
+vq_model: SynthesizerTrn | SynthesizerTrnV3 | None = None
+t2s_engine: T2SEngineProtocol | None = None
+
+version = model_version = "v2"
+path_sovits_v3 = pretrained_sovits_name["v3"]
+path_sovits_v4 = pretrained_sovits_name["v4"]
+is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+is_exist_s2gv4 = os.path.exists(path_sovits_v4)
+
+cnhubert_base_path = str(args.cnhubert)
+bert_path = str(args.bert)
+infer_ttswebui = int(args.port)
+is_share = bool(args.share)
+
+
+i18n = I18nAuto(language=args.language)
+ar_backend: str = args.backends
+change_choices_i18n = partial(change_choices, i18n=i18n)
+
+SoVITS_names, GPT_names = get_weights_names(i18n)
+
+
+# Maps the translated UI language label to the internal language code.
+dict_language_v1 = {
+ i18n("中文"): "all_zh", # treat all text as Chinese
+ i18n("英文"): "en", # treat all text as English
+ i18n("日文"): "all_ja", # treat all text as Japanese
+ i18n("中英混合"): "zh", # treat as mixed Chinese/English
+ i18n("日英混合"): "ja", # treat as mixed Japanese/English
+ i18n("多语种混合"): "auto", # multilingual: segment the text and detect each segment's language
+}
+dict_language_v2 = {
+ i18n("中文"): "all_zh", # treat all text as Chinese
+ i18n("英文"): "en", # treat all text as English
+ i18n("日文"): "all_ja", # treat all text as Japanese
+ i18n("粤语"): "all_yue", # treat all text as Cantonese
+ i18n("韩文"): "all_ko", # treat all text as Korean
+ i18n("中英混合"): "zh",
+ i18n("日英混合"): "ja",
+ i18n("粤英混合"): "yue",
+ i18n("韩英混合"): "ko",
+ i18n("多语种混合"): "auto", # multilingual: segment the text and detect each segment's language
+ i18n("多语种混合(粤语)"): "auto_yue", # multilingual: segment the text and detect each segment's language
+}
+# Table in effect for the current `version`; only "v1" selects the v1 table.
+dict_language = dict_language_v1 if version == "v1" else dict_language_v2
+
+punctuation = set(["!", "?", "…", ",", ".", "-", " "])
+splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
+v3v4set = {"v3", "v4"}
+
+infer_device = torch.device(args.device)
+device = infer_device if infer_device.type == "cuda" else torch.device("cpu")
+
+dtype = get_dtype(device.index)
+is_half = dtype == torch.float16
+
+tokenizer = AutoTokenizer.from_pretrained(bert_path)
+bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).to(infer_device, dtype)
+
+cnhubert.cnhubert_base_path = cnhubert_base_path
+ssl_model = cnhubert.get_model().to(infer_device, dtype)
+
+
+# Compute a mel spectrogram of `x` through `mel_spectrogram_torch` with fixed settings:
+# n_fft 1024, 100 mel bins, sampling rate 24000, hop 256, window 1024, center=False.
+def mel_fn(x):
+ return mel_spectrogram_torch(
+ y=x,
+ n_fft=1024,
+ num_mels=100,
+ sampling_rate=24000,
+ hop_size=256,
+ win_size=1024,
+ fmin=0,
+ fmax=None,
+ center=False,
+ )
+
+
+gpt_path = str(args.gpt) or GPT_names[0][-1]
+sovits_path = str(args.sovits) or SoVITS_names[0][-1]
+
+
+# Run `bert_model` on `text` and expand the per-token features to phone level.
+# `word2ph[i]` is the number of times the i-th feature row is repeated.
+# Returns the concatenated rows, transposed.
+def get_bert_feature(text, word2ph):
+ inputs = tokenizer(text, return_tensors="pt")
+ for i in inputs:
+ inputs[i] = inputs[i].to(infer_device)
+ res = bert_model(**inputs, output_hidden_states=True)
+ # [-3:-2] selects the single third-from-last hidden state; [0] takes the first batch item;
+ # [1:-1] drops the first and last positions (presumably the special tokens, TODO confirm)
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+
+ assert len(word2ph) == len(text)
+ phone_level_feature = []
+ for i in range(len(word2ph)):
+ repeat_feature = res[i].repeat(word2ph[i], 1)
+ phone_level_feature.append(repeat_feature)
+ phone_level_feature_t = torch.cat(phone_level_feature, dim=0)
+ return phone_level_feature_t.T
+
+
+# Generator that switches the SoVITS model: inspects the checkpoint at `sovits_path`,
+# rebuilds the global `vq_model` (SynthesizerTrn, or SynthesizerTrnV3 for v3/v4),
+# loads its weights and moves it to `infer_device`/`dtype`.
+# Each `yield` produces a 9-element tuple of Gradio updates; the last element drives the
+# synthesize button ("loading" before the model is built, re-enabled afterwards).
+# Raises FileNotFoundError when a LoRA checkpoint needs a base model that is not on disk.
+# NOTE(review): indentation is lost in this view, so whether the first `yield` sits inside the
+# `prompt_language is not None` branch cannot be confirmed here.
+def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
+ global vq_model, hps, version, model_version, dict_language
+ model_version, version, is_lora, hps, dict_s2 = inspect_version(sovits_path)
+ is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
+ path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
+ if is_lora is True and is_exist is False:
+ info = f"{path_sovits} SoVITS {model_version} {i18n('底模缺失,无法加载相应 LoRA 权重')}"
+ gr.Warning(info)
+ raise FileNotFoundError(info)
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
+ visible_sample_steps = visible_inp_refs = None
+ if prompt_language is not None and text_language is not None:
+ if prompt_language in list(dict_language.keys()):
+ prompt_text_update, prompt_language_update = gr.skip(), gr.update(choices=list(dict_language.keys()))
+ else:
+ prompt_text_update = gr.update(value="")
+ prompt_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
+ if text_language in list(dict_language.keys()):
+ text_update, text_language_update = gr.skip(), gr.skip()
+ else:
+ text_update = gr.update(value="")
+ text_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
+
+ if model_version in v3v4set:
+ visible_sample_steps = True
+ visible_inp_refs = False
+ else:
+ visible_sample_steps = False
+ visible_inp_refs = True
+ yield (
+ prompt_text_update,
+ prompt_language_update,
+ text_update,
+ text_language_update,
+ gr.update(
+ visible=visible_sample_steps,
+ value=32 if model_version == "v3" else 8,
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
+ ),
+ gr.update(visible=visible_inp_refs),
+ gr.update(value=False, interactive=True if model_version not in v3v4set else False),
+ gr.update(visible=True if model_version == "v3" else False),
+ gr.update(value=i18n("模型加载中,请等待"), interactive=False),
+ )
+
+ hps = DictToAttrRecursive(hps)
+ hps.model.semantic_frame_rate = "25hz"
+ hps.model.version = model_version
+ if model_version not in v3v4set:
+ vq_model = SynthesizerTrn(
+ hps.data.filter_length // 2 + 1,
+ hps.train.segment_size // hps.data.hop_length,
+ n_speakers=hps.data.n_speakers,
+ **hps.model,
+ )
+ else:
+ vq_model = SynthesizerTrnV3(
+ hps.data.filter_length // 2 + 1,
+ hps.train.segment_size // hps.data.hop_length,
+ n_speakers=hps.data.n_speakers,
+ **hps.model,
+ ).eval()
+
+ if "pretrained" not in sovits_path:
+ if hasattr(vq_model, "enc_q"):
+ del vq_model.enc_q
+
+ # strict=False: keys that do not match between checkpoint and model are tolerated
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
+
+ vq_model = vq_model.to(infer_device, dtype)
+
+ yield (
+ gr.skip(),
+ gr.skip(),
+ gr.skip(),
+ gr.skip(),
+ gr.skip(),
+ gr.skip(),
+ gr.skip(),
+ gr.skip(),
+ gr.update(value=i18n("合成语音"), interactive=True),
+ )
+
+
+with contextlib.suppress(UnboundLocalError):
+ next(change_sovits_weights(sovits_path))
+
+
+# Rebuild the global `t2s_engine` from the weights at `gpt_path`.
+# An MLX engine is used when "mlx" appears in `ar_backend` (case-insensitive), otherwise a
+# PyTorch engine; both decoders are loaded with `quantize_mode=args.quantization`.
+# NOTE(review): `config` is declared global but is not assigned in the visible body.
+def change_gpt_weights(gpt_path):
+ global t2s_engine, config
+ if "mlx" in ar_backend.lower():
+ t2s_engine = MLX.T2SEngineMLX(
+ MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
+ "mx.gpu" if infer_device.type != "cpu" else "mx.cpu",
+ dtype=dtype,
+ )
+ # t2s_engine.decoder_model.compile()
+ else:
+ t2s_engine = PyTorch.T2SEngineTorch(
+ PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
+ device,
+ dtype=dtype,
+ )
+ # t2s_engine.decoder_model.compile()
+
+
+change_gpt_weights(gpt_path)
+
+resample_transform_dict = {}
+
+
+# Resample `audio_tensor` from rate `sr0` to `sr1` on `device`.
+# The `torchaudio.transforms.Resample` object is created once per (sr0, sr1, device)
+# combination and cached in the module-level `resample_transform_dict`.
+def resample(audio_tensor, sr0, sr1, device):
+ global resample_transform_dict
+ key = f"{sr0}-{sr1}-{device}"
+ if key not in resample_transform_dict:
+ resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+ return resample_transform_dict[key](audio_tensor)
+
+
+# Load the audio file `filename` and compute its spectrogram with the settings in `hps.data`.
+# Returns `(spec, audio)`: `spec` is cast to `dtype`; when `is_v2pro` is True, `audio` is
+# additionally resampled to 16000 Hz and cast to `dtype`.
+def get_spepc(hps, filename, dtype, device, is_v2pro=False):
+ sr1 = int(hps.data.sampling_rate)
+ audio, sr0 = torchaudio.load_with_torchcodec(filename)
+ audio = audio.to(device)
+
+ if sr0 != sr1:
+ audio = resample(audio, sr0, sr1, device)
+ if audio.shape[0] > 1:
+ # average over the first dimension (presumably channels, TODO confirm), keeping a leading dim of 1
+ audio = audio.mean(0).unsqueeze(0)
+
+ maxx = float(audio.abs().max())
+ if maxx > 1:
+ # divisor is capped at 2, so peaks above 2 are attenuated but not fully normalised
+ audio /= min(2, maxx)
+ spec = spectrogram_torch(
+ audio,
+ hps.data.filter_length,
+ hps.data.sampling_rate,
+ hps.data.hop_length,
+ hps.data.win_length,
+ center=False,
+ )
+ spec = spec.to(dtype)
+ if is_v2pro is True:
+ audio = resample(audio, sr1, 16000, device).to(dtype)
+ return spec, audio
+
+
+# Clean `text` for `language` (with any "all_" removed from the code) and convert the
+# resulting phones with `cleaned_text_to_sequence`.
+# Returns `(phones, word2ph, norm_text)`.
+def clean_text_inf(text, language, version):
+ language = language.replace("all_", "")
+ phones, word2ph, norm_text = clean_text(text, language, version)
+ phones = cleaned_text_to_sequence(phones, version)
+ return phones, word2ph, norm_text
+
+
+# Return BERT features for one text segment, placed on `device`.
+# Only "zh" (after removing "all_") uses `get_bert_feature`; every other language gets a
+# zero tensor of shape (1024, len(phones)), float16 when `is_half` else float32.
+def get_bert_inf(phones, word2ph, norm_text, language):
+ language = language.replace("all_", "")
+ if language == "zh":
+ bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
+ else:
+ bert = torch.zeros(
+ (1024, len(phones)),
+ dtype=torch.float16 if is_half is True else torch.float32,
+ ).to(device)
+
+ return bert
+
+
+# Return the part of `text` before the first character found in `splits`, stripped of
+# surrounding whitespace.
+def get_first(text):
+ # character class built from every separator in `splits`, each escaped for regex use
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
+ text = re.split(pattern, text)[0].strip()
+ return text
+
+
+# Split `text` into per-language segments according to `language`, then clean each segment
+# and collect its phones and BERT features.
+# Returns `(phones, bert, norm_text)`: the flattened phone list, the BERT features
+# concatenated along dim 1 and cast to `dtype`, and the joined normalised text.
+# When fewer than 6 phones are produced and `final` is False, the function calls itself
+# once more with "." prepended to `text` and `final=True`.
+def get_phones_and_bert(text, language, version, final=False):
+ # collapse runs of two or more spaces into one
+ text = re.sub(r" {2,}", " ", text)
+ textlist = []
+ langlist = []
+ if language == "all_zh":
+ for tmp in LangSegmenter.getTexts(text, "zh"):
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ elif language == "all_yue":
+ for tmp in LangSegmenter.getTexts(text, "zh"):
+ if tmp["lang"] == "zh":
+ tmp["lang"] = "yue"
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ elif language == "all_ja":
+ for tmp in LangSegmenter.getTexts(text, "ja"):
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ elif language == "all_ko":
+ for tmp in LangSegmenter.getTexts(text, "ko"):
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ elif language == "en":
+ langlist.append("en")
+ textlist.append(text)
+ elif language == "auto":
+ for tmp in LangSegmenter.getTexts(text):
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ elif language == "auto_yue":
+ for tmp in LangSegmenter.getTexts(text):
+ if tmp["lang"] == "zh":
+ tmp["lang"] = "yue"
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ else:
+ for tmp in LangSegmenter.getTexts(text):
+ if langlist:
+ if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+ textlist[-1] += tmp["text"]
+ continue
+ if tmp["lang"] == "en":
+ langlist.append(tmp["lang"])
+ else:
+ # Chinese, Japanese and Korean Han characters cannot be told apart, so the user-selected language is used
+ langlist.append(language)
+ textlist.append(tmp["text"])
+ phones_list = []
+ bert_list = []
+ norm_text_list = []
+ for i in range(len(textlist)):
+ lang = langlist[i]
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
+ phones_list.append(phones)
+ norm_text_list.append(norm_text)
+ bert_list.append(bert)
+ bert = torch.cat(bert_list, dim=1)
+ phones = sum(phones_list, [])
+ norm_text = "".join(norm_text_list)
+
+ if not final and len(phones) < 6:
+ return get_phones_and_bert("." + text, language, version, final=True)
+
+ return phones, bert.to(dtype), norm_text
+
+
+def merge_short_text_in_array(texts, threshold):
+    """Concatenate consecutive strings until each piece reaches *threshold*.
+
+    A list with fewer than two items is returned unchanged (same object).
+    Otherwise strings are accumulated left to right and emitted once the
+    accumulated length is at least *threshold*; a leftover tail is appended
+    to the last emitted piece, or emitted on its own if nothing was emitted.
+    """
+    if len(texts) < 2:
+        return texts
+
+    merged = []
+    pending = ""
+    for piece in texts:
+        pending += piece
+        if len(pending) >= threshold:
+            merged.append(pending)
+            pending = ""
+
+    if pending:
+        if merged:
+            merged[-1] += pending
+        else:
+            merged.append(pending)
+    return merged
+
+
+# Lazily created super-resolution model; populated on the first successful audio_sr call.
+sr_model = None
+
+
+def audio_sr(audio, sr):
+    """Run *audio* through the super-resolution model, loading it on first use.
+
+    If constructing the model raises ``FileNotFoundError``, a Gradio warning
+    is shown and the input is returned as ``(audio.cpu().numpy(), sr)``.
+    Otherwise the result of calling the model with ``(audio, sr)`` is returned.
+    """
+    global sr_model
+    if sr_model is not None:
+        return sr_model(audio, sr)
+
+    from tools.audio_sr import AP_BWE
+
+    try:
+        sr_model = AP_BWE(infer_device, DictToAttrRecursive)
+    except FileNotFoundError:
+        gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
+        return audio.cpu().numpy(), sr
+    return sr_model(audio, sr)
+
+
+# Predicted semantic tokens keyed by text-segment index; get_tts_wav reuses an
+# entry when `if_freeze` is True and stores new predictions here.
+cache: dict[int, Any] = {}
+
+
+def get_tts_wav(
+ ref_wav_path,
+ prompt_text,
+ prompt_language,
+ text,
+ text_language,
+ how_to_cut=i18n("不切"),
+ top_k=20,
+ top_p=0.6,
+ temperature=0.6,
+ ref_free=False,
+ speed=1,
+ if_freeze=False,
+ inp_refs=None,
+ sample_steps=8,
+ if_sr=False,
+ pause_second=0.3,
+):
+ """Synthesize speech for `text` and yield it once as `(sample_rate, int16 samples)`.
+
+ This is a generator. It encodes the reference audio into a semantic prompt,
+ cuts `text` into segments, runs `t2s_engine` on each segment to predict
+ semantic tokens, decodes them with `vq_model`, and concatenates the segment
+ audio with `pause_second` seconds of silence after each segment. Timing
+ statistics are printed through `console` before the single yield.
+
+ Args:
+ ref_wav_path: Reference audio path; loaded with librosa at 16 kHz.
+ prompt_text: Transcript of the reference audio; empty/None enables ref_free.
+ prompt_language: Key into `dict_language` for the prompt.
+ text: Text to synthesize.
+ text_language: Key into `dict_language` for the target text.
+ how_to_cut: Localized label selecting one of cut1..cut5; other values leave the text uncut.
+ top_k, top_p, temperature: Sampling parameters forwarded to `T2SRequest`.
+ ref_free: Skip the reference prompt; forced back to False for models in `v3v4set`.
+ speed: Forwarded to `vq_model.decode`.
+ if_freeze: Reuse cached semantic tokens for a segment index when available.
+ inp_refs: Optional extra reference files (objects with a `.name` path) passed to `get_spepc`.
+ sample_steps: Not referenced in this function body.
+ if_sr: Not referenced in this function body.
+ pause_second: Seconds of silence used for the zero padding.
+
+ Raises:
+ OSError: The reference audio has fewer than 48000 or more than 160000 samples at 16 kHz.
+ RuntimeError: `t2s_engine.generate` reported an exception.
+ """
+ torch.set_grad_enabled(False)
+ debug = os.getenv("DEBUG") == "1"
+ # Start timestamp; replaced by the elapsed time-to-first-audio once segment 0 is decoded.
+ ttfb_time = ttime()
+
+ # NOTE(review): the two checks below only show a warning; execution continues
+ # with the missing input — confirm this is intended.
+ if ref_wav_path:
+ pass
+ else:
+ gr.Warning(i18n("请上传参考音频"))
+ if text:
+ pass
+ else:
+ gr.Warning(i18n("请填入推理文本"))
+ # Elapsed-time deltas: first the reference-preparation time, then three
+ # entries (t2 - t1, t3 - t2, t4 - t3) per synthesized segment.
+ t = []
+ if prompt_text is None or len(prompt_text) == 0:
+ ref_free = True
+ # NOTE(review): for v3/v4 models ref_free is reset even when prompt_text is
+ # None/empty, in which case the prompt_text handling below looks like it
+ # would fail — confirm callers always pass a prompt for these models.
+ if model_version in v3v4set:
+ ref_free = False # s2v3 does not support ref_free yet
+ t0 = ttime()
+ prompt_language = dict_language[prompt_language]
+ text_language = dict_language[text_language]
+
+ if not ref_free:
+ prompt_text = prompt_text.strip("\n")
+ # Make sure the prompt ends with a sentence separator.
+ if prompt_text[-1] not in splits:
+ prompt_text += "。" if prompt_language != "en" else "."
+ text = text.strip("\n")
+
+ # Silence of `pause_second` seconds, appended to the reference audio and
+ # after every synthesized segment.
+ zero_wav = np.zeros(
+ int(hps.data.sampling_rate * pause_second),
+ dtype=np.float16 if is_half is True else np.float32,
+ )
+ zero_wav_torch = torch.from_numpy(zero_wav)
+ if is_half is True:
+ zero_wav_torch = zero_wav_torch.half().to(infer_device)
+ else:
+ zero_wav_torch = zero_wav_torch.to(infer_device)
+ if not ref_free:
+ assert vq_model
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+ # 48000..160000 samples at 16 kHz corresponds to 3..10 seconds.
+ if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
+ gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
+ raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+ wav16k_t = torch.from_numpy(wav16k)
+ if is_half is True:
+ wav16k_t = wav16k_t.half().to(infer_device)
+ else:
+ wav16k_t = wav16k_t.to(infer_device)
+ wav16k_t = torch.cat([wav16k_t, zero_wav_torch])
+ # Reference audio -> SSL hidden states -> quantized codes used as the semantic prompt.
+ ssl_content = ssl_model.model(wav16k_t.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
+ codes = vq_model.extract_latent(ssl_content)
+ prompt_semantic = codes[0, 0]
+ prompt = prompt_semantic.unsqueeze(0).to(device)
+ else:
+ # No reference: empty int32 prompt.
+ prompt = torch.zeros((1, 0)).to(device, torch.int32)
+
+ t1 = ttime()
+ t.append(t1 - t0)
+
+ # Pick the text-cutting strategy from the localized UI label.
+ if how_to_cut == i18n("凑四句一切"):
+ text = cut1(text)
+ elif how_to_cut == i18n("凑50字一切"):
+ text = cut2(text)
+ elif how_to_cut == i18n("按中文句号。切"):
+ text = cut3(text)
+ elif how_to_cut == i18n("按英文句号.切"):
+ text = cut4(text)
+ elif how_to_cut == i18n("按标点符号切"):
+ text = cut5(text)
+ # Collapse blank lines, then treat each remaining line as one segment.
+ while "\n\n" in text:
+ text = text.replace("\n\n", "\n")
+ texts = text.split("\n")
+ texts = merge_short_text_in_array(texts, 5)
+ audio_opt = []
+ # s2v3 does not support ref_free yet
+ if not ref_free:
+ phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version)
+ else:
+ phones1, bert1 = [], torch.zeros(1024, 0).to(device, dtype)
+
+ # Per-segment token counts and speeds reported by the T2S engine.
+ infer_len: list[int] = []
+ infer_time: list[float] = []
+ assert vq_model
+
+ for i_text, text in enumerate(texts):
+ if len(text.strip()) == 0:
+ continue
+ if text[-1] not in splits:
+ text += "。" if text_language != "en" else "."
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
+
+ # Prompt and target features are concatenated along the phoneme axis.
+ bert = torch.cat([bert1, bert2], 1)
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
+
+ bert = bert.to(device).unsqueeze(0)
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+
+ t2 = ttime()
+ if i_text in cache and if_freeze is True:
+ pred_semantic = cache[i_text]
+ else:
+ t2s_request = T2SRequest(
+ [all_phoneme_ids.squeeze(0)],
+ all_phoneme_len,
+ prompt,
+ [bert.squeeze(0)],
+ valid_length=1,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ early_stop_num=1500,
+ use_cuda_graph=torch.cuda.is_available(), # Try to use CUDA Graph for all backends, fall back to normal if not applicable
+ debug=debug,
+ )
+ assert t2s_engine
+ t2s_result = t2s_engine.generate(t2s_request)
+ if t2s_result.exception is not None:
+ console.print(t2s_result.traceback)
+ raise RuntimeError()
+ pred_semantic_list = t2s_result.result
+ assert pred_semantic_list, t2s_result.traceback
+ pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device)
+ infer_len.append(t2s_result.total_tokens)
+ infer_time.append(t2s_result.infer_speed[-1])
+
+ cache[i_text] = pred_semantic
+ t3 = ttime()
+ is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
+
+ refers = []
+ if inp_refs:
+ for path in inp_refs:
+ try: # add speaker-verification (sv) extraction here: either many sv with many refers, or one sv with one refer
+ refer, audio_tensor = get_spepc(hps, path.name, dtype, infer_device, is_v2pro)
+ refers.append(refer)
+ except Exception as e:
+ print(e)
+ traceback.print_exc()
+ # Fall back to the main reference audio when no extra reference was usable.
+ if len(refers) == 0:
+ refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, infer_device, is_v2pro)
+ refers = [refers]
+ audio = vq_model.decode(
+ pred_semantic,
+ torch.LongTensor(phones2).to(infer_device).unsqueeze(0),
+ refers,
+ speed=speed,
+ )[0][0] # type: ignore
+
+ # NOTE(review): if segment 0 is skipped as empty, ttfb_time keeps the raw
+ # start timestamp and the TTFB printed later is meaningless — confirm.
+ if i_text == 0:
+ ttfb_time = ttime() - ttfb_time
+ max_audio = torch.abs(audio).max() # simple guard against 16-bit clipping
+ if max_audio > 1:
+ audio = audio / max_audio
+ audio_opt.append(audio)
+ audio_opt.append(zero_wav_torch) # zero_wav
+ t4 = ttime()
+ t.extend([t2 - t1, t3 - t2, t4 - t3])
+ t1 = ttime()
+
+ audio_opt_t = torch.cat(audio_opt, 0) # np.concatenate
+ # Output sample rate depends on the model version.
+ if model_version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
+ opt_sr = 32000
+ elif model_version == "v3":
+ opt_sr = 24000
+ else:
+ opt_sr = 48000 # v4
+ audio_opt_n = audio_opt_t.cpu().numpy()
+
+ # Aggregate the per-segment timings by stage (every third entry after t[0]).
+ t0 = t[0]
+ t1 = sum(t[1::3])
+ t2 = sum(t[2::3])
+ t3 = sum(t[3::3])
+
+ # NOTE(review): infer_time is only filled when the engine actually runs, so
+ # sum(infer_time) can be 0 (e.g. every segment served from `cache`) and this
+ # division would raise ZeroDivisionError — confirm and guard if needed.
+ infer_speed_avg = sum(infer_len) / sum(infer_time)
+ rtf_value = sum(t) / (audio_opt_n.__len__() / opt_sr)
+
+ console.print(f">> Time Stamps: {t0:.3f}\t{t1:.3f}\t{t2:.3f}\t{t3:.3f}")
+ console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
+ console.print(f">> RTF: {rtf_value:.2f}")
+
+ if ttfb_time > 2:
+ console.print(f">> TTFB: {ttfb_time:.3f} s")
+ else:
+ console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
+
+ # Scale to the int16 range for the consumer.
+ yield opt_sr, (audio_opt_n * 32767).astype(np.int16)
+
+ # Runs only if the generator is resumed after the yield.
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+def split(todo_text):
+    """Cut *todo_text* into pieces that each end with a separator character.
+
+    ``"……"`` is first replaced by ``"。"`` and ``"——"`` by ``","``.  If the
+    text does not end with a character from the module-level ``splits``
+    collection, ``"。"`` is appended so the last piece is terminated too.
+
+    Args:
+        todo_text: Text to cut.
+
+    Returns:
+        The list of pieces in order; joining them gives the (possibly
+        ``"。"``-terminated) text back.  Empty input yields an empty list
+        instead of raising ``IndexError`` on the last-character check.
+    """
+    todo_text = todo_text.replace("……", "。").replace("——", ",")
+    if not todo_text:
+        # Nothing to cut; indexing todo_text[-1] below would fail.
+        return []
+    if todo_text[-1] not in splits:
+        todo_text += "。"
+
+    todo_texts = []
+    piece_start = 0
+    for pos, char in enumerate(todo_text):
+        if char in splits:
+            # Each piece runs up to and including its separator.
+            todo_texts.append(todo_text[piece_start : pos + 1])
+            piece_start = pos + 1
+    # The text is guaranteed to end with a separator, so nothing is left over.
+    return todo_texts
+
+
+def cut1(inp):
+    """Group the sentences of *inp* four at a time, one group per output line.
+
+    Sentences come from ``split``; the final group absorbs any remainder.
+    Groups consisting only of characters in ``punctuation`` are dropped.
+    """
+    inp = inp.strip("\n")
+    sentences = split(inp)
+
+    # Group boundaries every 4 sentences; the last boundary is open-ended so
+    # the final group takes everything that is left.
+    boundaries: list[int | None] = list(range(0, len(sentences) + 1, 4))
+    boundaries[-1] = None
+
+    if len(boundaries) > 1:
+        groups = ["".join(sentences[lo:hi]) for lo, hi in zip(boundaries, boundaries[1:])]
+    else:
+        groups = [inp]
+
+    return "\n".join(group for group in groups if not set(group).issubset(punctuation))
+
+
+def cut2(inp):
+    """Pack the sentences of *inp* into chunks of more than 50 characters.
+
+    Sentences come from ``split``.  Text with fewer than two sentences is
+    returned as is.  A trailing chunk shorter than 50 characters is merged
+    into the previous one, and chunks made only of ``punctuation`` characters
+    are dropped.  Chunks are joined with newlines.
+    """
+    inp = inp.strip("\n")
+    sentences = split(inp)
+    if len(sentences) < 2:
+        return inp
+
+    chunks = []
+    buffer = ""
+    for sentence in sentences:
+        buffer += sentence
+        if len(buffer) > 50:
+            chunks.append(buffer)
+            buffer = ""
+    if buffer:
+        chunks.append(buffer)
+
+    # A final chunk that is too short is glued onto the one before it.
+    if len(chunks) > 1 and len(chunks[-1]) < 50:
+        short_tail = chunks.pop()
+        chunks[-1] += short_tail
+
+    return "\n".join(chunk for chunk in chunks if not set(chunk).issubset(punctuation))
+
+
+def cut3(inp):
+    """Split *inp* on the Chinese full stop ``。``, one sentence per output line.
+
+    Leading/trailing newlines and full stops are stripped first; pieces made
+    only of ``punctuation`` characters are dropped.
+    """
+    sentences = inp.strip("\n").strip("。").split("。")
+    return "\n".join(piece for piece in sentences if not set(piece).issubset(punctuation))
+
+
+def cut4(inp):
+    """Split *inp* on English periods, one sentence per output line.
+
+    A period is a split point only when neither of its neighbours is a
+    digit, so decimals such as ``3.14`` stay intact.  Pieces made only of
+    ``punctuation`` characters are dropped.
+    """
+    inp = inp.strip("\n")
+    opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
+    opts = [item for item in opts if not set(item).issubset(punctuation)]
+    return "\n".join(opts)
+
+
+def cut5(inp):
+    """Split *inp* after every punctuation mark, one piece per output line.
+
+    Each piece keeps its trailing punctuation mark.  A ``"."`` between two
+    digits is treated as a decimal point and does not end a piece.  Pieces
+    made only of punctuation marks are dropped.
+    """
+    inp = inp.strip("\n")
+    # NOTE(review): this set literal was lost together with the text between
+    # the two functions; the ASCII and CJK marks below are a reconstruction —
+    # confirm against the intended list.
+    punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
+    mergeitems = []
+    items = []
+
+    for i, char in enumerate(inp):
+        if char in punds:
+            if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
+                # Decimal point inside a number: keep accumulating.
+                items.append(char)
+            else:
+                # Punctuation closes the current piece.
+                items.append(char)
+                mergeitems.append("".join(items))
+                items = []
+        else:
+            items.append(char)
+
+    if items:
+        mergeitems.append("".join(items))
+
+    opt = [item for item in mergeitems if not set(item).issubset(punds)]
+    return "\n".join(opt)
+
+
+# Benchmark driver: synthesize the same target text three times (reference A,
+# reference B, reference A again), pulling the first yielded result each time
+# and calling the named `timer` method afterwards.
+_TARGET_TEXT = "我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜"
+_REF_A = (
+    "/Users/XXXXRT/Desktop/参考/不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊.wav",
+    "不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊",
+)
+_REF_B = (
+    "/Users/XXXXRT/Desktop/参考/Cream去能理解很多人的想法时,既然已经被这样想了,没有挽回的余地了.wav",
+    "去能理解很多人的想法时,既然已经被这样想了,没有挽回的余地了",
+)
+
+for (_ref_wav, _ref_text), _timer_action in (
+    (_REF_A, "clear"),
+    (_REF_B, "summary"),
+    (_REF_A, "summary"),
+):
+    a = get_tts_wav(
+        _ref_wav,
+        _ref_text,
+        i18n("中文"),
+        _TARGET_TEXT,
+        i18n("中文"),
+    )
+    next(a)
+    getattr(timer, _timer_action)()
+
+print("-" * 15 + "test2" + "-" * 15)
diff --git a/webui.py b/webui.py
index 602e9e1b..3c20f361 100644
--- a/webui.py
+++ b/webui.py
@@ -8,6 +8,7 @@ import traceback
from functools import partial
from multiprocessing import cpu_count
from subprocess import Popen
+from typing import cast
import gradio as gr
import psutil
@@ -37,7 +38,15 @@ from config import (
webui_port_subfix,
webui_port_uvr5,
)
-from GPT_SoVITS.Accelerate import backends, console, logger
+from GPT_SoVITS.Accelerate import (
+ MLX,
+ PyTorch,
+ backends,
+ console,
+ logger,
+ quantization_methods_mlx,
+ quantization_methods_torch,
+)
from tools import my_utils
from tools.asr.config import asr_dict
from tools.assets import css, js, top_html
@@ -310,8 +319,8 @@ def change_tts_inference(
sovits_path: str,
batched_infer_enabled: bool,
backends_dropdown: str,
+ quantization_methods_dropdown: str | None,
):
- console.print(gpt_path, sovits_path)
global p_tts_inference
env = os.environ.copy()
cmd: list[str] = [python_exec, "-s", "-m"]
@@ -334,6 +343,7 @@ def change_tts_inference(
"-b", backends_dropdown,
"-d", f"{infer_device.type}:{gpu_number}",
"-p", str(webui_port_infer_tts),
+ "-q", str(quantization_methods_dropdown),
"--gpt", gpt_path,
"--sovits", sovits_path,
]
@@ -344,7 +354,6 @@ def change_tts_inference(
if p_tts_inference is None:
yield (
- process_info(process_name_tts, "opened"),
gr.update(visible=False),
gr.update(visible=True),
)
@@ -354,7 +363,6 @@ def change_tts_inference(
kill_process(p_tts_inference.pid, process_name_tts)
p_tts_inference = None
yield (
- process_info(process_name_tts, "closed"),
gr.update(visible=True),
gr.update(visible=False),
)
@@ -1280,6 +1288,21 @@ def changeBackend(flag: bool):
return gr.update(choices=backends_gradio, value=backends_gradio[-1][-1])
+def changeQuantization(backend: str, gradio_call=True):
+    """Return the quantization choices available for *backend*.
+
+    The backend name is lower-cased and ``"-"`` is replaced by ``"_"`` before
+    it is looked up in ``MLX.backends`` and then ``PyTorch.backends``; an
+    unknown backend gets ``[None]`` as its only choice.
+
+    With ``gradio_call`` true the result is a ``gr.update`` that sets the
+    choices and resets the value to ``None``; otherwise the choices
+    themselves are returned.
+    """
+    backend_key = backend.lower().replace("-", "_")
+
+    if backend_key in MLX.backends:
+        choices = quantization_methods_mlx
+    elif backend_key in PyTorch.backends:
+        choices = quantization_methods_torch
+    else:
+        choices = [None]
+
+    return gr.update(choices=choices, value=None) if gradio_call else choices
+
+
GPU_INDEX.add(0)
GPU_INDEX_LIST = list(GPU_INDEX)
GPU_INDEX_LIST.sort()
@@ -1891,7 +1914,13 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
interactive=True,
)
with gr.Row(equal_height=True):
- tts_info = gr.Textbox(label=process_info(process_name_tts, "info"))
+ with gr.Column():
+ quantization_methods_dropdown = gr.Dropdown(
+ choices=cast(list, changeQuantization(backends_gradio[-1][-1], gradio_call=False)),
+ label=i18n("量化方法"),
+ value=None,
+ interactive=True,
+ )
open_tts = gr.Button(
value=process_info(process_name_tts, "open"), variant="primary", visible=True
)
@@ -1904,6 +1933,12 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
[batched_infer_enabled],
[backends_dropdown],
)
+ backends_dropdown.change(
+ changeQuantization,
+ [backends_dropdown],
+ [quantization_methods_dropdown],
+ )
+
open_tts.click(
change_tts_inference,
[
@@ -1912,8 +1947,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
SoVITS_dropdown,
batched_infer_enabled,
backends_dropdown,
+ quantization_methods_dropdown,
],
- [tts_info, open_tts, close_tts],
+ [open_tts, close_tts],
)
close_tts.click(
change_tts_inference,
@@ -1923,8 +1959,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
SoVITS_dropdown,
batched_infer_enabled,
backends_dropdown,
+ quantization_methods_dropdown,
],
- [tts_info, open_tts, close_tts],
+ [open_tts, close_tts],
)
button1Ba_open.click(
open1Ba,