This commit is contained in:
XXXXRT666 2025-09-04 21:52:29 +00:00
parent ebb1303627
commit 0b667c7b67
7 changed files with 10 additions and 4 deletions

View File

@ -126,6 +126,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
self,
decoder: T2SDecoder,
) -> None:
self.is_applicable = True
super().__init__(decoder)
def release_graph(self, session: T2SSession):

View File

@ -117,6 +117,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
self,
decoder,
) -> None:
self.is_applicable = False
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (

View File

@ -125,6 +125,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
self,
decoder: T2SDecoder,
) -> None:
self.is_applicable = False
super().__init__(decoder)
if torch.cuda.is_available():

View File

@ -117,6 +117,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
self,
decoder,
) -> None:
self.is_applicable = True
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (

View File

@ -132,7 +132,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
self,
decoder,
) -> None:
super().__init__(decoder, False)
self.is_applicable = False
super().__init__(decoder)
def release_graph(self, session: T2SSession):
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")

View File

@ -69,7 +69,7 @@ class T2SEngine(T2SEngineProtocol):
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
if request.use_cuda_graph and session.graph is None and self.graphcache.is_applicable and torch.cuda.is_available():
self.graphcache.assign_graph(session)
with torch_profiler.record("AR"):

View File

@ -558,9 +558,10 @@ class CUDAGraphCacheABC(ABC):
def __init__(
self,
decoder: T2SDecoderABC,
enabled: bool = False,
) -> None:
if torch.cuda.is_available() and enabled:
self.is_applicable: bool
if torch.cuda.is_available() and self.is_applicable:
self.device: torch.device = decoder.device
self.dtype = decoder.bert_proj.bias.dtype