This commit is contained in:
XXXXRT666 2025-10-24 03:18:39 +01:00
parent ca8c9ea768
commit 671d97fbf6
4 changed files with 18 additions and 2 deletions

View File

@ -25,6 +25,8 @@ class T2SEngine(T2SEngineProtocol):
decoder_model: T2SDecoderABC,
device: mx.Device | str = mx.Device(mx.cpu),
dtype: torch.dtype | mx.Dtype = torch.float32,
*args,
**kwds,
) -> None:
if isinstance(device, str):
match device:

View File

@ -24,6 +24,9 @@ class T2SEngine(T2SEngineProtocol):
decoder_model: T2SDecoderABC,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
cache_size: int = 5,
*args,
**kwds,
) -> None:
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
@ -34,7 +37,7 @@ class T2SEngine(T2SEngineProtocol):
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
# self.decoder_model.compile()
self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model)
self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model, cache_size)
def _handle_request(self, request: T2SRequest):
with self.device:

View File

@ -364,7 +364,13 @@ class TTS_Config:
class TTS:
def __init__(self, configs: dict | str | TTS_Config, ar_backend: str = backends[-1], quantization: Any = None):
def __init__(
self,
configs: dict | str | TTS_Config,
ar_backend: str = backends[-1],
quantization: Any = None,
cache_size: int = 5,
):
if isinstance(configs, TTS_Config):
self.configs = configs
else:
@ -405,6 +411,7 @@ class TTS:
self.backend: str = ar_backend
self.quantization: Any = quantization
self.cache_size: int = cache_size
self._init_models()
@ -529,6 +536,7 @@ class TTS:
),
"mx.gpu" if self.configs.device.type != "cpu" else "mx.cpu",
dtype=self.precision,
cache_size=self.cache_size,
)
else:
t2s_engine = PyTorch.T2SEngineTorch(
@ -537,6 +545,7 @@ class TTS:
),
self.configs.device if not torch.mps.is_available() else torch.device("cpu"),
dtype=self.precision,
cache_size=self.cache_size,
)
self.t2s_model = t2s_engine

View File

@ -412,6 +412,7 @@ async def change_gpt_weights(gpt_path):
MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
"mx.gpu" if infer_device.type != "cpu" else "mx.cpu",
dtype=dtype,
cache_size=1,
)
# t2s_engine.decoder_model.compile()
total = sum((p[-1].size for p in mxutils.tree_flatten(t2s_engine.decoder_model.parameters()))) # type: ignore
@ -420,6 +421,7 @@ async def change_gpt_weights(gpt_path):
PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
device,
dtype=dtype,
cache_size=1,
)
# t2s_engine.decoder_model.compile()
total = sum(p.numel() for p in t2s_engine.decoder_model.parameters())