mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-04 05:00:01 +08:00
.
This commit is contained in:
parent
630ce73c26
commit
3b624fc243
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,6 +19,7 @@ ref_audios
|
|||||||
tools/AP_BWE/24kto48k/*
|
tools/AP_BWE/24kto48k/*
|
||||||
!tools/AP_BWE/24kto48k/readme.txt
|
!tools/AP_BWE/24kto48k/readme.txt
|
||||||
onnx_export
|
onnx_export
|
||||||
|
compile_cache
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
@ -114,6 +114,9 @@ class T2SDecoder(T2SDecoderABC):
|
|||||||
|
|
||||||
self.kv_class = KVCacheNHD
|
self.kv_class = KVCacheNHD
|
||||||
|
|
||||||
|
def compile(self, *args, **kwds):
|
||||||
|
pass
|
||||||
|
|
||||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||||
return super().post_forward(idx, session)
|
return super().post_forward(idx, session)
|
||||||
|
|
||||||
@ -133,6 +136,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
|||||||
if session.id == self.id:
|
if session.id == self.id:
|
||||||
self.assigned = False
|
self.assigned = False
|
||||||
else:
|
else:
|
||||||
|
assert session.graph
|
||||||
|
session.graph.reset()
|
||||||
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
||||||
|
|
||||||
def get_cache_graph(self, session: T2SSession):
|
def get_cache_graph(self, session: T2SSession):
|
||||||
|
@ -107,6 +107,9 @@ class T2SDecoder(T2SDecoderABC):
|
|||||||
|
|
||||||
self.kv_class = KVCacheHND
|
self.kv_class = KVCacheHND
|
||||||
|
|
||||||
|
def compile(self, *args, **kwds):
|
||||||
|
pass
|
||||||
|
|
||||||
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
||||||
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
|
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
|
||||||
|
|
||||||
@ -136,6 +139,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
|||||||
if session.id == self.id:
|
if session.id == self.id:
|
||||||
self.assigned = False
|
self.assigned = False
|
||||||
else:
|
else:
|
||||||
|
assert session.graph
|
||||||
|
session.graph.reset()
|
||||||
del (
|
del (
|
||||||
session.graph,
|
session.graph,
|
||||||
session.xy_pos_,
|
session.xy_pos_,
|
||||||
|
@ -130,6 +130,8 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
|||||||
if session.id == self.id:
|
if session.id == self.id:
|
||||||
self.assigned = False
|
self.assigned = False
|
||||||
else:
|
else:
|
||||||
|
assert session.graph
|
||||||
|
session.graph.reset()
|
||||||
del (
|
del (
|
||||||
session.graph,
|
session.graph,
|
||||||
session.xy_pos_,
|
session.xy_pos_,
|
||||||
|
@ -32,7 +32,7 @@ class T2SEngine(T2SEngineProtocol):
|
|||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
||||||
self.decoder_model.compile()
|
# self.decoder_model.compile()
|
||||||
|
|
||||||
self.graphcache: CUDAGraphCacheABC = self.init_cache()
|
self.graphcache: CUDAGraphCacheABC = self.init_cache()
|
||||||
|
|
||||||
|
@ -6,9 +6,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
import platform
|
||||||
import random
|
import random
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from pathlib import Path
|
||||||
from typing import MutableSequence
|
from typing import MutableSequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -17,6 +20,8 @@ import torch.nn.functional as F
|
|||||||
from torch.cuda.graphs import CUDAGraph
|
from torch.cuda.graphs import CUDAGraph
|
||||||
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
||||||
|
|
||||||
|
from tools.my_utils import get_machine_id
|
||||||
|
|
||||||
from . import nn
|
from . import nn
|
||||||
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
||||||
|
|
||||||
@ -453,6 +458,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
|||||||
self.kv_class: type[KVCacheABC]
|
self.kv_class: type[KVCacheABC]
|
||||||
|
|
||||||
self.GraphCache: CUDAGraphCacheABC | None
|
self.GraphCache: CUDAGraphCacheABC | None
|
||||||
|
self.compiled: bool = False
|
||||||
|
|
||||||
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
||||||
self.ar_text_position = SinePositionalEmbedding(
|
self.ar_text_position = SinePositionalEmbedding(
|
||||||
@ -517,14 +523,61 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
|||||||
return xy_pos
|
return xy_pos
|
||||||
|
|
||||||
def compile(self, *args, **kwds):
|
def compile(self, *args, **kwds):
|
||||||
# Experimental features to reduce compilation times, will be on by default in future
|
if (
|
||||||
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
torch.cuda.is_available()
|
||||||
torch._inductor.config.coordinate_descent_tuning = True
|
and platform.system() != "Windows"
|
||||||
torch._inductor.config.triton.unique_kernel_names = True
|
or platform.system() == "macOS"
|
||||||
torch._inductor.config.fx_graph_cache = True
|
and self.compiled is False
|
||||||
torch._inductor.config.triton.cudagraph_trees = True
|
):
|
||||||
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
cache_path = Path.cwd() / "compile_cache"
|
||||||
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
if cache_path.exists() is False:
|
||||||
|
cache_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
assert cache_path.is_dir()
|
||||||
|
cache_file = (
|
||||||
|
cache_path
|
||||||
|
/ f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
|
||||||
|
)
|
||||||
|
if cache_file.exists():
|
||||||
|
try:
|
||||||
|
with open(cache_file, "rb") as f:
|
||||||
|
cache_data = pickle.load(f)
|
||||||
|
torch.compiler.load_cache_artifacts(cache_data)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to restore compile cache from {cache_file}: {e}")
|
||||||
|
|
||||||
|
# Experimental features to reduce compilation times, will be on by default in future
|
||||||
|
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
||||||
|
torch._inductor.config.coordinate_descent_tuning = True
|
||||||
|
torch._inductor.config.triton.unique_kernel_names = True
|
||||||
|
torch._inductor.config.fx_graph_cache = True
|
||||||
|
torch._inductor.config.triton.cudagraph_trees = True
|
||||||
|
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
||||||
|
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
||||||
|
self.compiled = True
|
||||||
|
|
||||||
|
def save_compile_cache(self):
|
||||||
|
if torch.cuda.is_available() and platform.system() != "Windows" or platform.system() == "macOS":
|
||||||
|
cache_path = Path.cwd() / "compile_cache"
|
||||||
|
if cache_path.exists() is False:
|
||||||
|
cache_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
assert cache_path.is_dir()
|
||||||
|
cache_file = (
|
||||||
|
cache_path
|
||||||
|
/ f"t2s_decoder_{self.n_layer}_{self.hidden_dim}_{self.n_head}_{self.ffn_dim}_{self.phoneme_vocab_size}_{get_machine_id()}_{torch.__version__}.GSV"
|
||||||
|
)
|
||||||
|
if cache_file.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
cache = torch.compiler.save_cache_artifacts()
|
||||||
|
assert cache
|
||||||
|
cache_data = cache[0]
|
||||||
|
with open(cache_file, "wb") as f:
|
||||||
|
pickle.dump(cache_data, f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to save compile cache to {cache_file}: {e}")
|
||||||
|
|
||||||
def capture(
|
def capture(
|
||||||
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
|
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, Union
|
from typing import IO, Union
|
||||||
|
|
||||||
@ -354,3 +357,11 @@ class _open_file(_opener[IO[bytes]]):
|
|||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self.file_like.close()
|
self.file_like.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_machine_id():
|
||||||
|
mac = uuid.getnode()
|
||||||
|
hostname = platform.node()
|
||||||
|
raw = f"{mac}-{hostname}"
|
||||||
|
serial = hashlib.md5(raw.encode()).hexdigest()[:20] # keep the first 20 hex characters
|
||||||
|
return serial
|
||||||
|
Loading…
x
Reference in New Issue
Block a user