mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-28 00:38:15 +08:00
.
This commit is contained in:
parent
ca8c9ea768
commit
671d97fbf6
@@ -25,6 +25,8 @@ class T2SEngine(T2SEngineProtocol):
|
||||
decoder_model: T2SDecoderABC,
|
||||
device: mx.Device | str = mx.Device(mx.cpu),
|
||||
dtype: torch.dtype | mx.Dtype = torch.float32,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
if isinstance(device, str):
|
||||
match device:
|
||||
|
||||
@@ -24,6 +24,9 @@ class T2SEngine(T2SEngineProtocol):
|
||||
decoder_model: T2SDecoderABC,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
cache_size: int = 5,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
|
||||
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
|
||||
@@ -34,7 +37,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
||||
# self.decoder_model.compile()
|
||||
|
||||
self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model)
|
||||
self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model, cache_size)
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
with self.device:
|
||||
|
||||
@@ -364,7 +364,13 @@ class TTS_Config:
|
||||
|
||||
|
||||
class TTS:
|
||||
def __init__(self, configs: dict | str | TTS_Config, ar_backend: str = backends[-1], quantization: Any = None):
|
||||
def __init__(
|
||||
self,
|
||||
configs: dict | str | TTS_Config,
|
||||
ar_backend: str = backends[-1],
|
||||
quantization: Any = None,
|
||||
cache_size: int = 5,
|
||||
):
|
||||
if isinstance(configs, TTS_Config):
|
||||
self.configs = configs
|
||||
else:
|
||||
@@ -405,6 +411,7 @@ class TTS:
|
||||
|
||||
self.backend: str = ar_backend
|
||||
self.quantization: Any = quantization
|
||||
self.cache_size: int = cache_size
|
||||
|
||||
self._init_models()
|
||||
|
||||
@@ -529,6 +536,7 @@ class TTS:
|
||||
),
|
||||
"mx.gpu" if self.configs.device.type != "cpu" else "mx.cpu",
|
||||
dtype=self.precision,
|
||||
cache_size=self.cache_size,
|
||||
)
|
||||
else:
|
||||
t2s_engine = PyTorch.T2SEngineTorch(
|
||||
@@ -537,6 +545,7 @@ class TTS:
|
||||
),
|
||||
self.configs.device if not torch.mps.is_available() else torch.device("cpu"),
|
||||
dtype=self.precision,
|
||||
cache_size=self.cache_size,
|
||||
)
|
||||
|
||||
self.t2s_model = t2s_engine
|
||||
|
||||
@@ -412,6 +412,7 @@ async def change_gpt_weights(gpt_path):
|
||||
MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
|
||||
"mx.gpu" if infer_device.type != "cpu" else "mx.cpu",
|
||||
dtype=dtype,
|
||||
cache_size=1,
|
||||
)
|
||||
# t2s_engine.decoder_model.compile()
|
||||
total = sum((p[-1].size for p in mxutils.tree_flatten(t2s_engine.decoder_model.parameters()))) # type: ignore
|
||||
@@ -420,6 +421,7 @@ async def change_gpt_weights(gpt_path):
|
||||
PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
|
||||
device,
|
||||
dtype=dtype,
|
||||
cache_size=1,
|
||||
)
|
||||
# t2s_engine.decoder_model.compile()
|
||||
total = sum(p.numel() for p in t2s_engine.decoder_model.parameters())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user