From 671d97fbf644c18ccb23947f9561496d2ea1f215 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 24 Oct 2025 03:18:39 +0100 Subject: [PATCH] . --- GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py | 2 ++ GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py | 5 ++++- GPT_SoVITS/TTS_infer_pack/TTS.py | 11 ++++++++++- GPT_SoVITS/inference_webui.py | 2 ++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py index 8f39611c..076c079e 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py @@ -25,6 +25,8 @@ class T2SEngine(T2SEngineProtocol): decoder_model: T2SDecoderABC, device: mx.Device | str = mx.Device(mx.cpu), dtype: torch.dtype | mx.Dtype = torch.float32, + *args, + **kwds, ) -> None: if isinstance(device, str): match device: diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index 9c560666..6f9319e3 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -24,6 +24,9 @@ class T2SEngine(T2SEngineProtocol): decoder_model: T2SDecoderABC, device: torch.device = torch.device("cpu"), dtype: torch.dtype = torch.float32, + cache_size: int = 5, + *args, + **kwds, ) -> None: assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"} assert dtype in {torch.float16, torch.bfloat16, torch.float32} @@ -34,7 +37,7 @@ class T2SEngine(T2SEngineProtocol): self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype) # self.decoder_model.compile() - self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model) + self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model, cache_size) def _handle_request(self, request: T2SRequest): with self.device: diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 68476762..ef850337 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -364,7 +364,13 @@ class TTS_Config: class TTS: - def __init__(self, configs: dict | str | TTS_Config, ar_backend: str = backends[-1], quantization: Any = None): + def __init__( + self, + configs: dict | str | TTS_Config, + ar_backend: str = backends[-1], + quantization: Any = None, + cache_size: int = 5, + ): if isinstance(configs, TTS_Config): self.configs = configs else: @@ -405,6 +411,7 @@ class TTS: self.backend: str = ar_backend self.quantization: Any = quantization + self.cache_size: int = cache_size self._init_models() @@ -529,6 +536,7 @@ class TTS: ), "mx.gpu" if self.configs.device.type != "cpu" else "mx.cpu", dtype=self.precision, + cache_size=self.cache_size, ) else: t2s_engine = PyTorch.T2SEngineTorch( @@ -537,6 +545,7 @@ class TTS: ), self.configs.device if not torch.mps.is_available() else torch.device("cpu"), dtype=self.precision, + cache_size=self.cache_size, ) self.t2s_model = t2s_engine diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index d7bfda5a..77cf9155 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -412,6 +412,7 @@ async def change_gpt_weights(gpt_path): MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization), "mx.gpu" if infer_device.type != "cpu" else "mx.cpu", dtype=dtype, + cache_size=1, ) # t2s_engine.decoder_model.compile() total = sum((p[-1].size for p in mxutils.tree_flatten(t2s_engine.decoder_model.parameters()))) # type: ignore @@ -420,6 +421,7 @@ async def change_gpt_weights(gpt_path): PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization), device, dtype=dtype, + cache_size=1, ) # t2s_engine.decoder_model.compile() total = sum(p.numel() for p in t2s_engine.decoder_model.parameters())