mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-01 19:18:21 +08:00
.
This commit is contained in:
parent
ebb1303627
commit
0b667c7b67
@ -126,6 +126,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
) -> None:
|
||||
self.is_applicable = True
|
||||
super().__init__(decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
|
||||
@ -117,6 +117,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
|
||||
@ -125,6 +125,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
) -> None:
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
|
||||
@ -117,6 +117,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
self.is_applicable = True
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
|
||||
@ -132,7 +132,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
super().__init__(decoder, False)
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
@ -69,7 +69,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
|
||||
if request.use_cuda_graph and session.graph is None and self.graphcache.is_applicable and torch.cuda.is_available():
|
||||
self.graphcache.assign_graph(session)
|
||||
|
||||
with torch_profiler.record("AR"):
|
||||
|
||||
@ -558,9 +558,10 @@ class CUDAGraphCacheABC(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderABC,
|
||||
enabled: bool = False,
|
||||
) -> None:
|
||||
if torch.cuda.is_available() and enabled:
|
||||
self.is_applicable: bool
|
||||
|
||||
if torch.cuda.is_available() and self.is_applicable:
|
||||
self.device: torch.device = decoder.device
|
||||
self.dtype = decoder.bert_proj.bias.dtype
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user