mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-04 05:00:01 +08:00
.
This commit is contained in:
parent
630ce73c26
commit
3b624fc243
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,6 +19,7 @@ ref_audios
|
||||
tools/AP_BWE/24kto48k/*
|
||||
!tools/AP_BWE/24kto48k/readme.txt
|
||||
onnx_export
|
||||
compile_cache
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
@ -114,6 +114,9 @@ class T2SDecoder(T2SDecoderABC):
|
||||
|
||||
self.kv_class = KVCacheNHD
|
||||
|
||||
def compile(self, *args, **kwds):
|
||||
pass
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
return super().post_forward(idx, session)
|
||||
|
||||
@ -133,6 +136,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
assert session.graph
|
||||
session.graph.reset()
|
||||
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
|
@ -107,6 +107,9 @@ class T2SDecoder(T2SDecoderABC):
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def compile(self, *args, **kwds):
|
||||
pass
|
||||
|
||||
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
||||
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
|
||||
|
||||
@ -136,6 +139,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
assert session.graph
|
||||
session.graph.reset()
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
|
@ -130,6 +130,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
assert session.graph
|
||||
session.graph.reset()
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
|
@ -32,7 +32,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
self.dtype = dtype
|
||||
|
||||
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
||||
self.decoder_model.compile()
|
||||
# self.decoder_model.compile()
|
||||
|
||||
self.graphcache: CUDAGraphCacheABC = self.init_cache()
|
||||
|
||||
|
@ -6,9 +6,12 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import MutableSequence
|
||||
|
||||
import torch
|
||||
@ -17,6 +20,8 @@ import torch.nn.functional as F
|
||||
from torch.cuda.graphs import CUDAGraph
|
||||
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
||||
|
||||
from tools.my_utils import get_machine_id
|
||||
|
||||
from . import nn
|
||||
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
||||
|
||||
@ -453,6 +458,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
self.kv_class: type[KVCacheABC]
|
||||
|
||||
self.GraphCache: CUDAGraphCacheABC | None
|
||||
self.compiled: bool = False
|
||||
|
||||
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
@ -517,14 +523,61 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
return xy_pos
|
||||
|
||||
def compile(self, *args, **kwds):
|
||||
# Experimental features to reduce compilation times, will be on by default in future
|
||||
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
torch._inductor.config.triton.cudagraph_trees = True
|
||||
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
||||
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and platform.system() != "Windows"
|
||||
or platform.system() == "macOS"
|
||||
and self.compiled is False
|
||||
):
|
||||
cache_path = Path.cwd() / "compile_cache"
|
||||
if cache_path.exists() is False:
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
assert cache_path.is_dir()
|
||||
cache_file = (
|
||||
cache_path
|
||||
/ f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
|
||||
)
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "rb") as f:
|
||||
cache_data = pickle.load(f)
|
||||
torch.compiler.load_cache_artifacts(cache_data)
|
||||
except Exception as e:
|
||||
print(f"Failed to resotore compile cache from {cache_file}: {e}")
|
||||
|
||||
# Experimental features to reduce compilation times, will be on by default in future
|
||||
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
torch._inductor.config.triton.cudagraph_trees = True
|
||||
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
||||
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
||||
self.compiled = True
|
||||
|
||||
def save_compile_cache(self):
|
||||
if torch.cuda.is_available() and platform.system() != "Windows" or platform.system() == "macOS":
|
||||
cache_path = Path.cwd() / "compile_cache"
|
||||
if cache_path.exists() is False:
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
assert cache_path.is_dir()
|
||||
cache_file = (
|
||||
cache_path
|
||||
/ f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
|
||||
)
|
||||
if cache_file.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
cache = torch.compiler.save_cache_artifacts()
|
||||
assert cache
|
||||
cache_data = cache[0]
|
||||
with open(cache_file, "wb") as f:
|
||||
pickle.dump(cache_data, f)
|
||||
except Exception as e:
|
||||
print(f"Failed to save compile cache to {cache_file}: {e}")
|
||||
|
||||
def capture(
|
||||
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
|
||||
|
@ -1,7 +1,10 @@
|
||||
import ctypes
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import IO, Union
|
||||
|
||||
@ -354,3 +357,11 @@ class _open_file(_opener[IO[bytes]]):
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.file_like.close()
|
||||
|
||||
|
||||
def get_machine_id():
|
||||
mac = uuid.getnode()
|
||||
hostname = platform.node()
|
||||
raw = f"{mac}-{hostname}"
|
||||
serial = hashlib.md5(raw.encode()).hexdigest()[:20] # 取前20位
|
||||
return serial
|
||||
|
Loading…
x
Reference in New Issue
Block a user