mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-02 11:48:12 +08:00
.
This commit is contained in:
parent
6f2e2c1e8d
commit
13bb0b23a1
@ -4,7 +4,7 @@ if importlib.util.find_spec("mlx") is not None:
|
||||
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
|
||||
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
|
||||
|
||||
backends = ["mlx_static", "mlx_quantized", "mlx_varlen"]
|
||||
backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
|
||||
else:
|
||||
backends = []
|
||||
|
||||
|
||||
@ -164,9 +164,18 @@ class T2SDecoder(T2SDecoderABC):
|
||||
self.kv_class = KVCacheHND
|
||||
self.group_size = 32
|
||||
self.bits = 8
|
||||
self.mode = "affine"
|
||||
|
||||
def set_mode(self, mode: str):
|
||||
assert mode in ["affine", "mxfp4"]
|
||||
self.mode = mode
|
||||
if self.mode == "mxfp4":
|
||||
self.bits = 4
|
||||
else:
|
||||
self.bits = 8
|
||||
|
||||
def quantized(self):
|
||||
nn.quantize(self, self.group_size, self.bits)
|
||||
nn.quantize(self, self.group_size, self.bits, mode=self.mode)
|
||||
# for layer in self.h.layers:
|
||||
# nn.quantize(layer.feed_forward, self.group_size, self.bits)
|
||||
# nn.quantize(layer.attention, self.group_size, self.bits)
|
||||
|
||||
@ -213,7 +213,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
|
||||
case "MLX-Static":
|
||||
decoder_cls = mlx_static.T2SDecoder
|
||||
case "MLX-Quantized":
|
||||
case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
|
||||
decoder_cls = mlx_quantized.T2SDecoder
|
||||
case _:
|
||||
raise RuntimeError(f"Backend {backend} Not Found")
|
||||
@ -226,6 +226,12 @@ class T2SEngine(T2SEngineProtocol):
|
||||
mx.eval(decoder)
|
||||
|
||||
if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
|
||||
if backend == "MLX-Quantized-Affine":
|
||||
decoder.set_mode("affine")
|
||||
elif backend == "MLX-Quantized-MXFP4":
|
||||
decoder.set_mode("mxfp4")
|
||||
else:
|
||||
raise RuntimeError(f"Quantized Backend {backend} Not Supported")
|
||||
decoder.quantized()
|
||||
mx.eval(decoder)
|
||||
|
||||
|
||||
@ -6,7 +6,13 @@ from .PyTorch.structs import T2SEngineProtocol
|
||||
backends = PyTorch.backends + MLX.backends
|
||||
|
||||
backends = [
|
||||
b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends
|
||||
b.replace("_", "-")
|
||||
.title()
|
||||
.replace("Mlx", "MLX")
|
||||
.replace("Mps", "MPS")
|
||||
.replace("Cuda", "CUDA")
|
||||
.replace("Mxfp4", "MXFP4")
|
||||
for b in backends
|
||||
]
|
||||
|
||||
|
||||
|
||||
10
config.py
10
config.py
@ -80,9 +80,6 @@ def get_weights_names(i18n):
|
||||
for key, value in name2sovits_path.items():
|
||||
if os.path.exists(value):
|
||||
SoVITS_names.append((i18n(key), value))
|
||||
for key, value in pretrained_sovits_name.items():
|
||||
if key in {"v3", "v4", "v1"}:
|
||||
SoVITS_names.append((value, value))
|
||||
for path in SoVITS_weight_root:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
@ -93,15 +90,20 @@ def get_weights_names(i18n):
|
||||
for key, value in name2gpt_path.items():
|
||||
if os.path.exists(value):
|
||||
GPT_names.append((i18n(key), value))
|
||||
GPT_names.append((pretrained_gpt_name["v1"], pretrained_gpt_name["v1"]))
|
||||
for path in GPT_weight_root:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
for name in os.listdir(path):
|
||||
if name.endswith(".ckpt"):
|
||||
GPT_names.append((f"{path}/{name}", f"{path}/{name}"))
|
||||
|
||||
SoVITS_names = sorted(SoVITS_names, key=custom_sort_key)
|
||||
GPT_names = sorted(GPT_names, key=custom_sort_key)
|
||||
|
||||
for key, value in pretrained_sovits_name.items():
|
||||
if key in {"v3", "v4", "v1"}:
|
||||
SoVITS_names.append((value, value))
|
||||
GPT_names.append((pretrained_gpt_name["v1"], pretrained_gpt_name["v1"]))
|
||||
return SoVITS_names, GPT_names
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user