From 13bb0b23a10e97f4556e9aaa39749dd4600285c4 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 5 Sep 2025 04:34:55 +0800 Subject: [PATCH] . --- GPT_SoVITS/Accelerate/MLX/__init__.py | 2 +- GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py | 11 ++++++++++- GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py | 8 +++++++- GPT_SoVITS/Accelerate/__init__.py | 8 +++++++- config.py | 10 ++++++---- 5 files changed, 31 insertions(+), 8 deletions(-) diff --git a/GPT_SoVITS/Accelerate/MLX/__init__.py b/GPT_SoVITS/Accelerate/MLX/__init__.py index 4042c0f1..6ee66db2 100644 --- a/GPT_SoVITS/Accelerate/MLX/__init__.py +++ b/GPT_SoVITS/Accelerate/MLX/__init__.py @@ -4,7 +4,7 @@ if importlib.util.find_spec("mlx") is not None: from .sample_funcs_mlx import sample_naive as sample_naive_mlx from .t2s_engine_mlx import T2SEngine as T2SEngineMLX - backends = ["mlx_static", "mlx_quantized", "mlx_varlen"] + backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"] else: backends = [] diff --git a/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py b/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py index 600f401e..a624b4a5 100644 --- a/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py +++ b/GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py @@ -164,9 +164,18 @@ class T2SDecoder(T2SDecoderABC): self.kv_class = KVCacheHND self.group_size = 32 self.bits = 8 + self.mode = "affine" + + def set_mode(self, mode: str): + assert mode in ["affine", "mxfp4"] + self.mode = mode + if self.mode == "mxfp4": + self.bits = 4 + else: + self.bits = 8 def quantized(self): - nn.quantize(self, self.group_size, self.bits) + nn.quantize(self, self.group_size, self.bits, mode=self.mode) # for layer in self.h.layers: # nn.quantize(layer.feed_forward, self.group_size, self.bits) # nn.quantize(layer.attention, self.group_size, self.bits) diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py index ec94cbee..fff9b191 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py @@ -213,7 +213,7 @@ class T2SEngine(T2SEngineProtocol): decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder case "MLX-Static": decoder_cls = mlx_static.T2SDecoder - case "MLX-Quantized": + case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4": decoder_cls = mlx_quantized.T2SDecoder case _: raise RuntimeError(f"Backend {backend} Not Found") @@ -226,6 +226,12 @@ class T2SEngine(T2SEngineProtocol): mx.eval(decoder) if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder): + if backend == "MLX-Quantized-Affine": + decoder.set_mode("affine") + elif backend == "MLX-Quantized-MXFP4": + decoder.set_mode("mxfp4") + else: + raise RuntimeError(f"Quantized Backend {backend} Not Supported") decoder.quantized() mx.eval(decoder) diff --git a/GPT_SoVITS/Accelerate/__init__.py b/GPT_SoVITS/Accelerate/__init__.py index 6ecdba15..797fe1d0 100644 --- a/GPT_SoVITS/Accelerate/__init__.py +++ b/GPT_SoVITS/Accelerate/__init__.py @@ -6,7 +6,13 @@ from .PyTorch.structs import T2SEngineProtocol backends = PyTorch.backends + MLX.backends backends = [ - b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends + b.replace("_", "-") + .title() + .replace("Mlx", "MLX") + .replace("Mps", "MPS") + .replace("Cuda", "CUDA") + .replace("Mxfp4", "MXFP4") + for b in backends ] diff --git a/config.py b/config.py index 333f7f22..c2e5169a 100644 --- a/config.py +++ b/config.py @@ -80,9 +80,6 @@ def get_weights_names(i18n): for key, value in name2sovits_path.items(): if os.path.exists(value): SoVITS_names.append((i18n(key), value)) - for key, value in pretrained_sovits_name.items(): - if key in {"v3", "v4", "v1"}: - SoVITS_names.append((value, value)) for path in SoVITS_weight_root: if not os.path.exists(path): continue @@ -93,15 +90,20 @@ def get_weights_names(i18n): for key, value in name2gpt_path.items(): if os.path.exists(value): GPT_names.append((i18n(key), value)) - GPT_names.append((pretrained_gpt_name["v1"], pretrained_gpt_name["v1"])) for path in GPT_weight_root: if not os.path.exists(path): continue for name in os.listdir(path): if name.endswith(".ckpt"): GPT_names.append((f"{path}/{name}", f"{path}/{name}")) + SoVITS_names = sorted(SoVITS_names, key=custom_sort_key) GPT_names = sorted(GPT_names, key=custom_sort_key) + + for key, value in pretrained_sovits_name.items(): + if key in {"v3", "v4", "v1"}: + SoVITS_names.append((value, value)) + GPT_names.append((pretrained_gpt_name["v1"], pretrained_gpt_name["v1"])) return SoVITS_names, GPT_names