This commit is contained in:
XXXXRT666 2025-10-24 04:20:15 +01:00
parent f25fcfb0e4
commit 0651cb5d01

View File

@ -692,6 +692,7 @@ class CUDAGraphCacheABC(ABC):
self.graph_cache: dict[int, Queue[CUDAGraphStateABC]] = {}
if torch.cuda.is_available() and torch.version.cuda is not None and os.environ.get("CUDAGraph", "1") != "0":
self.graph_cache[1] = Queue()
self.create_graph_cache(1)
def __getitem__(self, bsz: int) -> CUDAGraphStateABC: