From 0b667c7b67314cefe5f07e7090615399df23a94f Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Thu, 4 Sep 2025 21:52:29 +0000 Subject: [PATCH] . --- .../PyTorch/backends/flash_attn_varlen_cuda_graph.py | 1 + .../Accelerate/PyTorch/backends/mps_flash_attn_varlen.py | 1 + .../PyTorch/backends/sage_attn_varlen_cuda_graph.py | 1 + .../Accelerate/PyTorch/backends/torch_static_cuda_graph.py | 1 + GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py | 3 ++- GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py | 2 +- GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py | 5 +++-- 7 files changed, 10 insertions(+), 4 deletions(-) diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py index 666c1b63..d8a33f4e 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py @@ -126,6 +126,7 @@ class CUDAGraphCache(CUDAGraphCacheABC): self, decoder: T2SDecoder, ) -> None: + self.is_applicable = True super().__init__(decoder) def release_graph(self, session: T2SSession): diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py b/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py index 6f15f51e..6c83fcc3 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py @@ -117,6 +117,7 @@ class CUDAGraphCache(CUDAGraphCacheABC): self, decoder, ) -> None: + self.is_applicable = False super().__init__(decoder) if torch.cuda.is_available(): self.attn_mask = ( diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py index d94b567c..3212e755 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py @@ -125,6 +125,7 @@ class CUDAGraphCache(CUDAGraphCacheABC): self, decoder: T2SDecoder, ) -> None: + self.is_applicable = False super().__init__(decoder) if torch.cuda.is_available(): diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py index 6f15f51e..8eb677f7 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py @@ -117,6 +117,7 @@ class CUDAGraphCache(CUDAGraphCacheABC): self, decoder, ) -> None: + self.is_applicable = True super().__init__(decoder) if torch.cuda.is_available(): self.attn_mask = ( diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py index d079e9af..a0d1be61 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py +++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py @@ -132,7 +132,8 @@ class CUDAGraphCache(CUDAGraphCacheABC): self, decoder, ) -> None: - super().__init__(decoder, False) + self.is_applicable = False + super().__init__(decoder) def release_graph(self, session: T2SSession): raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model") diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index e0402f7a..c3ba17bc 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -69,7 +69,7 @@ class T2SEngine(T2SEngineProtocol): xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask) xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] else: - if request.use_cuda_graph and session.graph is None and torch.cuda.is_available(): + if request.use_cuda_graph and session.graph is None and self.graphcache.is_applicable and torch.cuda.is_available(): self.graphcache.assign_graph(session) with torch_profiler.record("AR"): diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py index ad2bddcd..5d5d57bb 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py @@ -558,9 +558,10 @@ class CUDAGraphCacheABC(ABC): def __init__( self, decoder: T2SDecoderABC, - enabled: bool = False, ) -> None: - if torch.cuda.is_available() and enabled: + self.is_applicable: bool + + if torch.cuda.is_available() and self.is_applicable: self.device: torch.device = decoder.device self.dtype = decoder.bert_proj.bias.dtype