This commit is contained in:
XXXXRT666 2025-09-30 06:57:43 +01:00
parent ac42f69e25
commit 630ce73c26

View File

@@ -32,6 +32,7 @@ class T2SEngine(T2SEngineProtocol):
self.dtype = dtype
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
self.decoder_model.compile()
self.graphcache: CUDAGraphCacheABC = self.init_cache()
@@ -183,7 +184,7 @@ class T2SEngine(T2SEngineProtocol):
gc.collect()
torch_profiler.end()
if request.use_cuda_graph and torch.cuda.is_available():
if request.use_cuda_graph and self.graphcache.is_applicable:
self.graphcache.release_graph(session)
return session.y_results[: request.valid_length], infer_speed, infer_time
@@ -194,6 +195,9 @@ class T2SEngine(T2SEngineProtocol):
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
if self.decoder_model.compiled:
self.decoder_model.save_compile_cache()
self.compiled = None
return t2s_result
@staticmethod