XXXXRT666 2025-09-30 07:14:38 +01:00
parent 630ce73c26
commit 3b624fc243
7 changed files with 86 additions and 9 deletions

.gitignore vendored
View File

@@ -19,6 +19,7 @@ ref_audios
 tools/AP_BWE/24kto48k/*
 !tools/AP_BWE/24kto48k/readme.txt
 onnx_export
+compile_cache
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

View File

@@ -114,6 +114,9 @@ class T2SDecoder(T2SDecoderABC):
         self.kv_class = KVCacheNHD
 
+    def compile(self, *args, **kwds):
+        pass
+
     def post_forward(self, idx: int, session: T2SSession) -> None:
         return super().post_forward(idx, session)
@@ -133,6 +136,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
         if session.id == self.id:
             self.assigned = False
         else:
+            assert session.graph
+            session.graph.reset()
             del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
 
     def get_cache_graph(self, session: T2SSession):

View File

@@ -107,6 +107,9 @@ class T2SDecoder(T2SDecoderABC):
         self.kv_class = KVCacheHND
 
+    def compile(self, *args, **kwds):
+        pass
+
     def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
         return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
@@ -136,6 +139,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
         if session.id == self.id:
             self.assigned = False
         else:
+            assert session.graph
+            session.graph.reset()
             del (
                 session.graph,
                 session.xy_pos_,

View File

@@ -130,6 +130,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
         if session.id == self.id:
             self.assigned = False
         else:
+            assert session.graph
+            session.graph.reset()
             del (
                 session.graph,
                 session.xy_pos_,

View File

@@ -32,7 +32,7 @@ class T2SEngine(T2SEngineProtocol):
         self.dtype = dtype
 
         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
-        self.decoder_model.compile()
+        # self.decoder_model.compile()
 
         self.graphcache: CUDAGraphCacheABC = self.init_cache()

View File

@@ -6,9 +6,12 @@ from __future__ import annotations
 
 import math
 import os
+import pickle
+import platform
 import random
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
+from pathlib import Path
 from typing import MutableSequence
 
 import torch
@@ -17,6 +20,8 @@ import torch.nn.functional as F
 from torch.cuda.graphs import CUDAGraph
 from torch.profiler import ProfilerAction, tensorboard_trace_handler
 
+from tools.my_utils import get_machine_id
+
 from . import nn
 from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
@@ -453,6 +458,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
         self.kv_class: type[KVCacheABC]
         self.GraphCache: CUDAGraphCacheABC | None
+        self.compiled: bool = False
 
         self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
         self.ar_text_position = SinePositionalEmbedding(
@@ -517,6 +523,29 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
         return xy_pos
 
     def compile(self, *args, **kwds):
+        if (
+            torch.cuda.is_available()
+            and platform.system() != "Windows"
+            or platform.system() == "Darwin"
+        ) and self.compiled is False:
+            cache_path = Path.cwd() / "compile_cache"
+            if cache_path.exists() is False:
+                cache_path.mkdir(parents=True, exist_ok=True)
+            else:
+                assert cache_path.is_dir()
+            cache_file = (
+                cache_path
+                / f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
+            )
+            if cache_file.exists():
+                try:
+                    with open(cache_file, "rb") as f:
+                        cache_data = pickle.load(f)
+                    torch.compiler.load_cache_artifacts(cache_data)
+                except Exception as e:
+                    print(f"Failed to restore compile cache from {cache_file}: {e}")
         # Experimental features to reduce compilation times, will be on by default in future
         torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
         torch._inductor.config.coordinate_descent_tuning = True
@@ -525,6 +554,30 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
         torch._inductor.config.triton.cudagraph_trees = True
         torch._inductor.config.triton.cudagraph_support_input_mutation = True
 
         self.h.compile(fullgraph=True, mode="reduce-overhead")
+        self.compiled = True
+
+    def save_compile_cache(self):
+        if torch.cuda.is_available() and platform.system() != "Windows" or platform.system() == "Darwin":
+            cache_path = Path.cwd() / "compile_cache"
+            if cache_path.exists() is False:
+                cache_path.mkdir(parents=True, exist_ok=True)
+            else:
+                assert cache_path.is_dir()
+            cache_file = (
+                cache_path
+                / f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
+            )
+            if cache_file.exists():
+                return
+            try:
+                cache = torch.compiler.save_cache_artifacts()
+                assert cache
+                cache_data = cache[0]
+                with open(cache_file, "wb") as f:
+                    pickle.dump(cache_data, f)
+            except Exception as e:
+                print(f"Failed to save compile cache to {cache_file}: {e}")
 
     def capture(
         self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
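
For reference, the round trip that compile() and save_compile_cache() implement above looks roughly like the following minimal sketch, assuming the torch.compiler.save_cache_artifacts / load_cache_artifacts APIs used in this hunk; the toy function and cache file name are illustrative only, not part of the commit.

import pickle
from pathlib import Path

import torch


@torch.compile
def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) + 1.0


cache_file = Path("compile_cache") / "toy_demo.GSV"  # hypothetical name
cache_file.parent.mkdir(parents=True, exist_ok=True)

if cache_file.exists():
    # Warm start: seed the compiler caches before the first compiled call.
    torch.compiler.load_cache_artifacts(pickle.loads(cache_file.read_bytes()))

toy(torch.randn(8))  # first call triggers (or reuses) compilation

saved = torch.compiler.save_cache_artifacts()
if saved is not None:
    artifact_bytes, _cache_info = saved
    cache_file.write_bytes(pickle.dumps(artifact_bytes))

On a later run the load branch pre-populates the compiler caches before the first compiled call, which is the effect the decoder gets from its machine- and version-specific .GSV file.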

View File

@@ -1,7 +1,10 @@
 import ctypes
+import hashlib
 import io
 import os
+import platform
 import sys
+import uuid
 from pathlib import Path
 from typing import IO, Union
@@ -354,3 +357,11 @@ class _open_file(_opener[IO[bytes]]):
     def __exit__(self, *args):
         self.file_like.close()
 
+
+def get_machine_id():
+    mac = uuid.getnode()
+    hostname = platform.node()
+    raw = f"{mac}-{hostname}"
+    serial = hashlib.md5(raw.encode()).hexdigest()[:20]  # keep the first 20 hex characters
+    return serial
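
The id is used to key the compile-cache filename to a single machine; below is a rough, self-contained illustration of how it feeds the *.GSV name built in compile() / save_compile_cache() above, with hypothetical decoder parameters (the real values come from the decoder config).

import hashlib
import platform
import uuid

import torch


def get_machine_id() -> str:
    mac = uuid.getnode()        # 48-bit MAC address as an int (may fall back to a random value)
    hostname = platform.node()  # this machine's network name
    raw = f"{mac}-{hostname}"
    return hashlib.md5(raw.encode()).hexdigest()[:20]


# Hypothetical decoder parameters, for illustration only.
n_layer, hidden_dim, n_head, ffn_dim, vocab = 24, 512, 16, 2048, 732
print(f"t2s_decoder_{n_layer}_{hidden_dim}_{n_head}_{ffn_dim}_{vocab}_{get_machine_id()}_{torch.__version__}.GSV")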