This commit is contained in:
XXXXRT666 2025-09-05 04:34:55 +08:00
parent 6f2e2c1e8d
commit 13bb0b23a1
5 changed files with 31 additions and 8 deletions

View File

@ -4,7 +4,7 @@ if importlib.util.find_spec("mlx") is not None:
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
backends = ["mlx_static", "mlx_quantized", "mlx_varlen"]
backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
else:
backends = []

View File

@ -164,9 +164,18 @@ class T2SDecoder(T2SDecoderABC):
self.kv_class = KVCacheHND
self.group_size = 32
self.bits = 8
self.mode = "affine"
def set_mode(self, mode: str):
assert mode in ["affine", "mxfp4"]
self.mode = mode
if self.mode == "mxfp4":
self.bits = 4
else:
self.bits = 8
def quantized(self):
nn.quantize(self, self.group_size, self.bits)
nn.quantize(self, self.group_size, self.bits, mode=self.mode)
# for layer in self.h.layers:
# nn.quantize(layer.feed_forward, self.group_size, self.bits)
# nn.quantize(layer.attention, self.group_size, self.bits)

View File

@ -213,7 +213,7 @@ class T2SEngine(T2SEngineProtocol):
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
case "MLX-Static":
decoder_cls = mlx_static.T2SDecoder
case "MLX-Quantized":
case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
decoder_cls = mlx_quantized.T2SDecoder
case _:
raise RuntimeError(f"Backend {backend} Not Found")
@ -226,6 +226,12 @@ class T2SEngine(T2SEngineProtocol):
mx.eval(decoder)
if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
if backend == "MLX-Quantized-Affine":
decoder.set_mode("affine")
elif backend == "MLX-Quantized-MXFP4":
decoder.set_mode("mxfp4")
else:
raise RuntimeError(f"Quantized Backend {backend} Not Supported")
decoder.quantized()
mx.eval(decoder)

View File

@ -6,7 +6,13 @@ from .PyTorch.structs import T2SEngineProtocol
backends = PyTorch.backends + MLX.backends
backends = [
b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends
b.replace("_", "-")
.title()
.replace("Mlx", "MLX")
.replace("Mps", "MPS")
.replace("Cuda", "CUDA")
.replace("Mxfp4", "MXFP4")
for b in backends
]

View File

@ -80,9 +80,6 @@ def get_weights_names(i18n):
for key, value in name2sovits_path.items():
if os.path.exists(value):
SoVITS_names.append((i18n(key), value))
for key, value in pretrained_sovits_name.items():
if key in {"v3", "v4", "v1"}:
SoVITS_names.append((value, value))
for path in SoVITS_weight_root:
if not os.path.exists(path):
continue
@ -93,15 +90,20 @@ def get_weights_names(i18n):
for key, value in name2gpt_path.items():
if os.path.exists(value):
GPT_names.append((i18n(key), value))
GPT_names.append((pretrained_gpt_name["v1"], pretrained_gpt_name["v1"]))
for path in GPT_weight_root:
if not os.path.exists(path):
continue
for name in os.listdir(path):
if name.endswith(".ckpt"):
GPT_names.append((f"{path}/{name}", f"{path}/{name}"))
SoVITS_names = sorted(SoVITS_names, key=custom_sort_key)
GPT_names = sorted(GPT_names, key=custom_sort_key)
for key, value in pretrained_sovits_name.items():
if key in {"v3", "v4", "v1"}:
SoVITS_names.append((value, value))
GPT_names.append((pretrained_gpt_name["v1"], pretrained_gpt_name["v1"]))
return SoVITS_names, GPT_names