diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index 423f4953..179f870f 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -32,6 +32,7 @@ class T2SEngine(T2SEngineProtocol): self.dtype = dtype self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype) + self.decoder_model.compile() self.graphcache: CUDAGraphCacheABC = self.init_cache() @@ -183,7 +184,7 @@ class T2SEngine(T2SEngineProtocol): gc.collect() torch_profiler.end() - if request.use_cuda_graph and torch.cuda.is_available(): + if request.use_cuda_graph and self.graphcache.is_applicable: self.graphcache.release_graph(session) return session.y_results[: request.valid_length], infer_speed, infer_time @@ -194,6 +195,9 @@ class T2SEngine(T2SEngineProtocol): t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success") except Exception as e: t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc()) + if self.decoder_model.compiled: + self.decoder_model.save_compile_cache() + self.compiled = None return t2s_result @staticmethod