This commit is contained in:
XXXXRT666 2025-09-30 07:14:38 +01:00
parent 630ce73c26
commit 3b624fc243
7 changed files with 86 additions and 9 deletions

1
.gitignore vendored
View File

@ -19,6 +19,7 @@ ref_audios
tools/AP_BWE/24kto48k/*
!tools/AP_BWE/24kto48k/readme.txt
onnx_export
compile_cache
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@ -114,6 +114,9 @@ class T2SDecoder(T2SDecoderABC):
self.kv_class = KVCacheNHD
def compile(self, *args, **kwds):
    """Do nothing: compilation is a no-op for this decoder.

    Any positional or keyword arguments are accepted and ignored.
    """
    return None
def post_forward(self, idx: int, session: T2SSession) -> None:
    """Forward ``idx`` and ``session`` to the parent class's ``post_forward`` unchanged."""
    outcome = super().post_forward(idx, session)
    return outcome
@ -133,6 +136,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
if session.id == self.id:
self.assigned = False
else:
assert session.graph
session.graph.reset()
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
def get_cache_graph(self, session: T2SSession):

View File

@ -107,6 +107,9 @@ class T2SDecoder(T2SDecoderABC):
self.kv_class = KVCacheHND
def compile(self, *args, **kwds):
    """Do nothing: compilation is a no-op for this decoder.

    Any positional or keyword arguments are accepted and ignored.
    """
    return None
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
    """Build the extra forward arguments for ``session``.

    Returns:
        A pair ``(args, kwargs)``: ``args`` is always a new empty list, and
        ``kwargs`` carries the session's ``cu_seqlens_q`` and
        ``cu_seqlens_kv`` attributes under keys of the same names.
    """
    extra_args = []
    extra_kwargs = {
        "cu_seqlens_q": session.cu_seqlens_q,
        "cu_seqlens_kv": session.cu_seqlens_kv,
    }
    return extra_args, extra_kwargs
@ -136,6 +139,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
if session.id == self.id:
self.assigned = False
else:
assert session.graph
session.graph.reset()
del (
session.graph,
session.xy_pos_,

View File

@ -130,6 +130,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
if session.id == self.id:
self.assigned = False
else:
assert session.graph
session.graph.reset()
del (
session.graph,
session.xy_pos_,

View File

@ -32,7 +32,7 @@ class T2SEngine(T2SEngineProtocol):
self.dtype = dtype
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
self.decoder_model.compile()
# self.decoder_model.compile()
self.graphcache: CUDAGraphCacheABC = self.init_cache()

View File

@ -6,9 +6,12 @@ from __future__ import annotations
import math
import os
import pickle
import platform
import random
from abc import ABC, abstractmethod
from contextlib import nullcontext
from pathlib import Path
from typing import MutableSequence
import torch
@ -17,6 +20,8 @@ import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler
from tools.my_utils import get_machine_id
from . import nn
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
@ -453,6 +458,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
self.kv_class: type[KVCacheABC]
self.GraphCache: CUDAGraphCacheABC | None
self.compiled: bool = False
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
self.ar_text_position = SinePositionalEmbedding(
@ -517,14 +523,61 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
return xy_pos
def compile(self, *args, **kwds):
    """Compile ``self.h`` with Inductor, restoring cached compile artifacts first.

    The method is a no-op when the model has already been compiled, or when
    the platform is not supported (supported: CUDA on a non-Windows system,
    or macOS). On the first successful call it

    1. ensures ``<cwd>/compile_cache`` exists,
    2. tries to preload previously saved compile artifacts from a cache file
       keyed by model dimensions, machine id and torch version (failures are
       reported and otherwise ignored),
    3. enables a set of experimental Inductor options and calls
       ``self.h.compile(fullgraph=True, mode="reduce-overhead")``,
    4. sets ``self.compiled`` to ``True``.

    Args:
        *args: Ignored.
        **kwds: Ignored.

    Raises:
        FileExistsError: If ``compile_cache`` exists but is not a directory.
    """
    # Check the flag first and on its own: in the previous combined condition
    # `and` bound tighter than `or`, so the flag never guarded the CUDA branch
    # and every call re-ran the compilation setup.
    if self.compiled:
        return

    # platform.system() reports macOS as "Darwin"; it never returns "macOS".
    system = platform.system()
    supported = (torch.cuda.is_available() and system != "Windows") or system == "Darwin"
    if not supported:
        return

    cache_path = Path.cwd() / "compile_cache"
    # mkdir(exist_ok=True) already raises if the path exists as a non-directory,
    # and unlike an `assert` it is not stripped under `python -O`.
    cache_path.mkdir(parents=True, exist_ok=True)

    cache_file = (
        cache_path
        / f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
    )

    if cache_file.exists():
        # Best effort: a missing, stale or corrupt cache only costs compile time.
        try:
            with open(cache_file, "rb") as f:
                # NOTE(security): pickle executes arbitrary code on load. This is
                # only safe while the compile_cache directory is trusted and
                # written exclusively by save_compile_cache().
                cache_data = pickle.load(f)
            torch.compiler.load_cache_artifacts(cache_data)
        except Exception as e:
            print(f"Failed to restore compile cache from {cache_file}: {e}")

    # Experimental features to reduce compilation times, will be on by default in future
    torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    torch._inductor.config.fx_graph_cache = True
    torch._inductor.config.triton.cudagraph_trees = True
    torch._inductor.config.triton.cudagraph_support_input_mutation = True
    self.h.compile(fullgraph=True, mode="reduce-overhead")

    self.compiled = True
def save_compile_cache(self):
    """Write the current compile artifacts to ``<cwd>/compile_cache`` if not already saved.

    Does nothing on unsupported platforms (supported: CUDA on a non-Windows
    system, or macOS) or when the cache file for this model configuration,
    machine id and torch version already exists. Failures while collecting or
    writing the artifacts are reported with ``print`` and otherwise ignored,
    since the cache is only an optimization.

    Raises:
        FileExistsError: If ``compile_cache`` exists but is not a directory.
    """
    # platform.system() reports macOS as "Darwin"; it never returns "macOS".
    system = platform.system()
    supported = (torch.cuda.is_available() and system != "Windows") or system == "Darwin"
    if not supported:
        return

    cache_path = Path.cwd() / "compile_cache"
    # mkdir(exist_ok=True) already raises if the path exists as a non-directory,
    # and unlike an `assert` it is not stripped under `python -O`.
    cache_path.mkdir(parents=True, exist_ok=True)

    cache_file = (
        cache_path
        / f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
    )
    if cache_file.exists():
        return

    # Write to a sibling temp file and rename it into place, so an interrupted
    # dump never leaves a truncated cache file that would block later saves.
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    try:
        cache = torch.compiler.save_cache_artifacts()
        if not cache:
            raise RuntimeError("no compile artifacts are available to save")
        cache_data = cache[0]  # first element is the serialized artifact payload
        with open(tmp_file, "wb") as f:
            pickle.dump(cache_data, f)
        tmp_file.replace(cache_file)
    except Exception as e:
        print(f"Failed to save compile cache to {cache_file}: {e}")
        try:
            tmp_file.unlink()
        except OSError:
            # Nothing to clean up, or cleanup itself failed; the save is best effort.
            pass
def capture(
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds

View File

@ -1,7 +1,10 @@
import ctypes
import hashlib
import io
import os
import platform
import sys
import uuid
from pathlib import Path
from typing import IO, Union
@ -354,3 +357,11 @@ class _open_file(_opener[IO[bytes]]):
def __exit__(self, *args):
    """Close the wrapped file object when the ``with`` block ends.

    The exception details passed in ``args`` are ignored.
    """
    handle = self.file_like
    handle.close()
def get_machine_id():
    """Return a 20-character hex identifier for the current machine.

    The identifier is the first 20 hex digits of the MD5 digest of
    ``"<uuid.getnode()>-<platform.node()>"``.
    """
    fingerprint = f"{uuid.getnode()}-{platform.node()}"
    digest = hashlib.md5(fingerprint.encode()).hexdigest()
    # Keep only the first 20 characters.
    return digest[:20]