diff --git a/.gitignore b/.gitignore
index ae31dad9..d9e61e9f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ ref_audios
 tools/AP_BWE/24kto48k/*
 !tools/AP_BWE/24kto48k/readme.txt
 onnx_export
+compile_cache

 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
index 3343e948..6b699198 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
@@ -114,6 +114,9 @@ class T2SDecoder(T2SDecoderABC):

         self.kv_class = KVCacheNHD

+    def compile(self, *args, **kwds):
+        pass
+
     def post_forward(self, idx: int, session: T2SSession) -> None:
         return super().post_forward(idx, session)

@@ -133,6 +136,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
         if session.id == self.id:
             self.assigned = False
         else:
+            assert session.graph
+            session.graph.reset()
             del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache

     def get_cache_graph(self, session: T2SSession):
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
index 6ff762ed..ddd150c3 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
@@ -107,6 +107,9 @@ class T2SDecoder(T2SDecoderABC):

         self.kv_class = KVCacheHND

+    def compile(self, *args, **kwds):
+        pass
+
     def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
         return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)

@@ -136,6 +139,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
         if session.id == self.id:
             self.assigned = False
         else:
+            assert session.graph
+            session.graph.reset()
             del (
                 session.graph,
                 session.xy_pos_,
diff --git a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
index 7bd1bd70..961bc19a 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
@@ -130,6 +130,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
         if session.id == self.id:
             self.assigned = False
         else:
+            assert session.graph
+            session.graph.reset()
             del (
                 session.graph,
                 session.xy_pos_,
diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
index 179f870f..b1d7cbf0 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
@@ -32,7 +32,7 @@ class T2SEngine(T2SEngineProtocol):
         self.dtype = dtype

         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
-        self.decoder_model.compile()
+        # self.decoder_model.compile()

         self.graphcache: CUDAGraphCacheABC = self.init_cache()

diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
index fe8bf75e..d86d44e4 100644
--- a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
+++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
@@ -6,9 +6,12 @@ from __future__ import annotations

 import math
 import os
+import pickle
+import platform
 import random
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
+from pathlib import Path
 from typing import MutableSequence

 import torch
@@ -17,6 +20,8 @@ import torch.nn.functional as F
 from torch.cuda.graphs import CUDAGraph
 from torch.profiler import ProfilerAction, tensorboard_trace_handler

+from tools.my_utils import get_machine_id
+
 from . import nn
 from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession

@@ -453,6 +458,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):

         self.kv_class: type[KVCacheABC]
         self.GraphCache: CUDAGraphCacheABC | None
+        self.compiled: bool = False

         self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
         self.ar_text_position = SinePositionalEmbedding(
@@ -517,14 +523,61 @@
         return xy_pos

     def compile(self, *args, **kwds):
-        # Experimental features to reduce compilation times, will be on by default in future
-        torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
-        torch._inductor.config.coordinate_descent_tuning = True
-        torch._inductor.config.triton.unique_kernel_names = True
-        torch._inductor.config.fx_graph_cache = True
-        torch._inductor.config.triton.cudagraph_trees = True
-        torch._inductor.config.triton.cudagraph_support_input_mutation = True
-        self.h.compile(fullgraph=True, mode="reduce-overhead")
+        if (
+            # CUDA on non-Windows, or macOS (platform.system() reports "Darwin", never "macOS")
+            (torch.cuda.is_available() and platform.system() != "Windows")
+            or platform.system() == "Darwin"
+        ) and self.compiled is False:
+            # compile (and touch the on-disk cache) only once per process
+            cache_path = Path.cwd() / "compile_cache"
+            if cache_path.exists() is False:
+                cache_path.mkdir(parents=True, exist_ok=True)
+            else:
+                assert cache_path.is_dir()
+            cache_file = (
+                cache_path
+                / f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
+            )
+            if cache_file.exists():
+                try:
+                    with open(cache_file, "rb") as f:
+                        cache_data = pickle.load(f)
+                    torch.compiler.load_cache_artifacts(cache_data)
+                except Exception as e:
+                    print(f"Failed to restore compile cache from {cache_file}: {e}")
+
+            # Experimental features to reduce compilation times, will be on by default in future
+            torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
+            torch._inductor.config.coordinate_descent_tuning = True
+            torch._inductor.config.triton.unique_kernel_names = True
+            torch._inductor.config.fx_graph_cache = True
+            torch._inductor.config.triton.cudagraph_trees = True
+            torch._inductor.config.triton.cudagraph_support_input_mutation = True
+            self.h.compile(fullgraph=True, mode="reduce-overhead")
+            self.compiled = True
+
+    def save_compile_cache(self):
+        if (torch.cuda.is_available() and platform.system() != "Windows") or platform.system() == "Darwin":
+            cache_path = Path.cwd() / "compile_cache"
+            if cache_path.exists() is False:
+                cache_path.mkdir(parents=True, exist_ok=True)
+            else:
+                assert cache_path.is_dir()
+            cache_file = (
+                cache_path
+                / f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
+            )
+            if cache_file.exists():
+                return
+
+            try:
+                cache = torch.compiler.save_cache_artifacts()
+                assert cache
+                cache_data = cache[0]
+                with open(cache_file, "wb") as f:
+                    pickle.dump(cache_data, f)
+            except Exception as e:
+                print(f"Failed to save compile cache to {cache_file}: {e}")

     def capture(
         self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
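The two new methods above are a persistence wrapper around PyTorch's portable compile-artifact API (torch.compiler.save_cache_artifacts / load_cache_artifacts, the "Mega-Cache"). A minimal standalone sketch of that round trip, assuming torch >= 2.7 with a CUDA device; fn and the demo.GSV path are illustrative only, not part of this patch:

    # Round trip for torch.compile cache artifacts: compile once, persist the
    # artifact bytes, and restore them in a later process before compiling again.
    import pickle
    from pathlib import Path

    import torch


    def fn(x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x) + torch.cos(x)


    compiled = torch.compile(fn)
    compiled(torch.randn(8, device="cuda"))  # first call triggers compilation and fills the caches

    artifacts = torch.compiler.save_cache_artifacts()  # None if nothing has been compiled yet
    assert artifacts is not None
    artifact_bytes, cache_info = artifacts
    Path("compile_cache").mkdir(exist_ok=True)
    Path("compile_cache/demo.GSV").write_bytes(pickle.dumps(artifact_bytes))

    # In a fresh process, load the artifacts BEFORE torch.compile so compilation is a warm start:
    torch.compiler.load_cache_artifacts(pickle.loads(Path("compile_cache/demo.GSV").read_bytes()))

This ordering is why compile() restores the cache before self.h.compile(), and why save_compile_cache() is only useful after at least one compiled forward pass has run.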
diff --git a/tools/my_utils.py b/tools/my_utils.py
index 12ab7dc0..19e147a6 100644
--- a/tools/my_utils.py
+++ b/tools/my_utils.py
@@ -1,7 +1,10 @@
 import ctypes
+import hashlib
 import io
 import os
+import platform
 import sys
+import uuid
 from pathlib import Path
 from typing import IO, Union

@@ -354,3 +357,11 @@ class _open_file(_opener[IO[bytes]]):

     def __exit__(self, *args):
         self.file_like.close()
+
+
+def get_machine_id():
+    mac = uuid.getnode()
+    hostname = platform.node()
+    raw = f"{mac}-{hostname}"
+    serial = hashlib.md5(raw.encode()).hexdigest()[:20]  # keep the first 20 hex characters
+    return serial
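get_machine_id() keys the cache files per machine: an MD5 over the MAC (uuid.getnode()) plus hostname, truncated to 20 hex characters, combined in the file name with torch.__version__ so artifacts built on different hardware or a different PyTorch build are never loaded. One caveat: uuid.getnode() falls back to a random 48-bit value when no hardware address can be read, so on such systems the id (and thus the cache file name) may change between runs. A hypothetical sanity check, assuming the repo layout above:

    from tools.my_utils import get_machine_id

    mid = get_machine_id()
    assert mid == get_machine_id()  # deterministic for a given MAC + hostname
    assert len(mid) == 20 and int(mid, 16) >= 0  # first 20 hex chars of an MD5 digest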