"""
|
|
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
import os
|
|
import random
|
|
from abc import ABC, abstractmethod
|
|
from contextlib import nullcontext
|
|
from typing import MutableSequence
|
|
|
|
import torch
|
|
import torch._inductor.config
|
|
import torch.nn.functional as F
|
|
from torch.cuda.graphs import CUDAGraph
|
|
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
|
|
|
from . import nn
|
|
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
|
|
|
Tensor = torch.Tensor
|
|
|
|
|
|
class TokenEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def __call__(self, x: Tensor):
        x = self.word_embeddings(x)
        return x


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        scale: bool = False,
        alpha: bool = False,
        max_batch_size: int = 10,
        max_seq_len: int = 1800,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.reverse = False
        self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
        self.pe: torch.Tensor
        self.compute_pe()

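    # The "pe" buffer holds the standard Transformer sinusoids,
    #   PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))
    #   PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim)),
    # replicated across the batch dimension so that incremental decoding can
    # gather one row per sample using a per-sample position index.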
    def compute_pe(self):
        """Precompute the sinusoidal positional encodings in place."""
        if self.reverse:
            position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
        )
        pe = self.pe
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)

    def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
        """
        Args:
            input_pos (Tensor): [batch_size, ]
            x (Tensor): [batch_size, 1, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, 1, embed_dim]
        """

        batch_size = x.shape[0]
        pe_values = self.pe[torch.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)

        return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): [batch_size, seq_len, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, seq_len, embed_dim]
        """

        pe_values = self.pe[:, : x.shape[-2]]
        return x * self.x_scale + self.alpha.item() * pe_values


class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
    def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
        super().__init__()

        self.n_head = n_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length

        self.k_cache: Tensor
        self.v_cache: Tensor

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    @abstractmethod
    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...

    @abstractmethod
    def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...

    def sync_cache(self, kv_cache: KVCacheProtocol):
        self.k_cache.copy_(kv_cache.k_cache)
        self.v_cache.copy_(kv_cache.v_cache)


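# The concrete caches below differ in memory layout and update strategy:
# KVCacheNHD stores (batch, seq, heads, head_dim) and writes new entries with
# scatter_ along the sequence dimension; KVCacheHND stores
# (batch, heads, seq, head_dim), the layout consumed directly by
# torch.nn.functional.scaled_dot_product_attention; KVCacheHNDVarlen keeps the
# same HND layout but writes through advanced indexing with a cached batch
# index, trading scatter_ for an index_put_-style update.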
class KVCacheNHD(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__(batch_size, max_seq_length, n_heads, head_dim)

        assert batch_size > 0
        cache_shape = (batch_size, max_seq_length, n_heads, head_dim)

        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, 1, H, D]

        index = (
            (input_pos - 1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(
                -1,
                -1,
                self.n_head,
                self.head_dim,
            )
            .to(torch.int64)
        )  # (bs, 1, num_head, head_dim)

        k_out = self.k_cache
        v_out = self.v_cache
        k_out.scatter_(1, index, k_val)
        v_out.scatter_(1, index, v_val)

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # k_val: [B, S, H, D]

        self.k_cache[:, : k_val.shape[1]] = k_val
        self.v_cache[:, : v_val.shape[1]] = v_val


class KVCacheHND(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__(batch_size, max_seq_length, n_heads, head_dim)

        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)

        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, H, 1, D]

        index = (
            (input_pos - 1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(
                -1,
                self.n_head,
                -1,
                self.head_dim,
            )
            .to(torch.int64)
        )  # (bs, num_head, 1, head_dim)

        k_out = self.k_cache
        v_out = self.v_cache
        k_out.scatter_(2, index, k_val)
        v_out.scatter_(2, index, v_val)

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # k_val: [B, S, H, D]

        self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
        self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)


class KVCacheHNDVarlen(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__(batch_size, max_seq_length, n_heads, head_dim)

        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
        self.cache_idx: Tensor

        self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, H, 1, D]

        k_out = self.k_cache
        v_out = self.v_cache

        ip0 = input_pos - 1

        k_out[self.cache_idx, :, ip0, None] = k_val
        v_out[self.cache_idx, :, ip0, None] = v_val

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # k_val: [B, S, H, D]

        self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
        self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)


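# A minimal sketch of how a decode step drives one of these caches (the names
# `cache`, `pos`, `k_new`, and `v_new` are illustrative, not part of this
# module):
#
#   cache = KVCacheHND(batch_size=2, max_seq_length=1800, n_heads=16, head_dim=32)
#   pos = torch.tensor([5, 7])           # per-sample lengths after this step
#   k_new = torch.randn(2, 16, 1, 32)    # [B, H, 1, D] for the newest token
#   v_new = torch.randn(2, 16, 1, 32)
#   k_all, v_all = cache.update(pos, k_new, v_new)  # full [B, H, S_max, D]
#
# update() returns the whole cache tensors, so the caller is expected to mask
# out positions beyond each sample's current length.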
class AttentionABC(nn.Module, ABC):
    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
        super().__init__()

        self.n_head = n_head
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head

        self.max_seq_length = max_seq_length

        # key, query, value projections for all heads, but in a batch
        self.in_proj: nn.Linear
        self.out_proj: nn.Linear

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
        keys_to_modify = [key for key in state_dict if "in_proj_" in key]
        for key in keys_to_modify:
            new_key = key.replace("in_proj_", "in_proj.")  # in_proj_ -> in_proj.
            state_dict[new_key] = state_dict.pop(key)

    @abstractmethod
    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...

    def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)

        q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))

        kv_cache.prefill_kv(k, v)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)

        attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)

        output = self.out_proj(attn)

        return output


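# Note on AttentionABC.load_hook: checkpoints that store the fused projection
# as "in_proj_weight"/"in_proj_bias" (the naming used by fused
# nn.MultiheadAttention-style modules, which the upstream model presumably
# used) are renamed to "in_proj.weight"/"in_proj.bias" so they load into the
# single nn.Linear here, whose output is chunked into Q, K and V in one matmul.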
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()

        self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=True)

    def __call__(self, x: Tensor):
        return self.linear2(F.relu(self.linear1(x), inplace=True))


class TransformerBlockABC(nn.Module, ABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_seq_length = max_seq_length

        self.attention: AttentionABC
        self.feed_forward: FeedForward
        self.attention_norm: nn.LayerNorm
        self.ffn_norm: nn.LayerNorm

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
        for key in list(state_dict.keys()):
            new_key = (
                key.replace("self_attn", "attention")
                .replace("linear", "feed_forward.linear")
                .replace("norm1", "attention_norm")
                .replace("norm2", "ffn_norm")
            )
            state_dict[new_key] = state_dict.pop(key)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
        h = self.attention_norm(
            x
            + self.attention(
                x,
                input_pos,
                kv_cache,
                *args,
                **kwds,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))
        return out

    def prefill(
        self,
        x: Tensor,
        kv_cache: KVCacheProtocol,
        attn_mask: Tensor,
    ) -> Tensor:
        h = self.attention_norm(
            x
            + self.attention.prefill(
                x,
                kv_cache,
                attn_mask,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))
        return out


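# Both paths above apply post-norm residuals, norm(x + sublayer(x)), i.e. the
# structure of nn.TransformerEncoderLayer with norm_first=False; the load_hook
# renames (self_attn -> attention, norm1/norm2 -> attention_norm/ffn_norm)
# exist so checkpoints saved with the original module names keep loading after
# this refactor.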
class TransformerDecoderABC(nn.Module, ABC):
    def __init__(
        self,
        hidden_dim: int,
        n_layer: int,
        n_head: int,
        ffn_dim: int,
        vocab_size: int,
        max_seq_length: int,
        max_batch_size: int,
    ) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0

        self.head_dim = hidden_dim // n_head
        self.vocab_size = vocab_size

        self.n_layer = n_layer

        self.layers: MutableSequence[TransformerBlockABC]

        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

    def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer(x, input_pos, kv_cache, *args, **kwds)
        return x

    def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.prefill(x, kv_cache, attn_mask)
        return x


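# The decoder owns no KV storage of its own: callers pass one cache per layer
# (see T2SDecoderABC.init_cache below), which keeps cache allocation, CUDA
# graph capture, and reuse across sessions outside the forward path.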
class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
    def __init__(
        self,
        config: dict,
        max_seq_length: int = 1800,
        max_batch_size: int = 10,
    ) -> None:
        super().__init__()

        hidden_dim: int = config["model"]["hidden_dim"]
        embedding_dim: int = config["model"]["embedding_dim"]
        n_head: int = config["model"]["head"]
        n_layer: int = config["model"]["n_layer"]
        vocab_size: int = config["model"]["vocab_size"]
        phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
        EOS: int = config["model"]["EOS"]
        ffn_dim: int = hidden_dim * 4

        self.n_layer = int(n_layer)
        self.hidden_dim = int(hidden_dim)
        self.n_head = int(n_head)
        assert hidden_dim % n_head == 0

        self.head_dim = int(hidden_dim // n_head)
        self.embedding_dim = int(embedding_dim)
        self.ffn_dim = int(ffn_dim)
        self.vocab_size = int(vocab_size)
        self.phoneme_vocab_size = int(phoneme_vocab_size)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.EOS = EOS
        assert self.EOS == self.vocab_size - 1

        self.bert_proj: nn.Linear
        self.ar_predict_layer: nn.Linear
        self.h: TransformerDecoderABC

        self.kv_class: type[KVCacheABC]

        self.GraphCache: CUDAGraphCacheABC | None

        self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim,
            scale=False,
            alpha=True,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )
        self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim,
            scale=False,
            alpha=True,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
        model_keys = [key for key in state_dict if key.startswith("model.")]
        for key in model_keys:
            new_key = key[len("model.") :]
            state_dict[new_key] = state_dict.pop(key)

    def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
        bsz = bsz or self.h.max_batch_size
        assert bsz <= self.h.max_batch_size
        seq_lens = self.h.max_seq_length
        dtype = self.bert_proj.bias.dtype
        kvclass = self.kv_class

        return nn.ModuleList(
            [kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
        ).to(self.device, dtype)  # type: ignore

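    # init_cache allocates every layer's cache at (bsz, max_seq_length) up
    # front; the fixed shapes and storage addresses are what later allow the
    # decode step to be captured and replayed as a CUDA graph.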
    def embed(
        self,
        x: list[torch.Tensor],
        y: torch.Tensor,
        bert_features: list[torch.Tensor],
    ):
        x_len: list[int] = [i.shape[0] for i in x]
        x_len_max = max(x_len)
        xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)

        bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))

        y_len = y.shape[1]
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position.prefill(y_emb)

        for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
            x_emb = self.ar_text_embedding(x_)
            bert = self.bert_proj(bert_feature)
            x_emb = x_emb + bert
            x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
            xy_pos[[bs], :len_] = x_pos
            xy_pos[[bs], len_ : len_ + y_len] = y_pos

        return xy_pos

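    # embed packs each sample as [phoneme embeddings + projected BERT
    # features | audio prompt embeddings], left-aligned in a zero-padded
    # (batch, max_text_len + prompt_len, embedding_dim) tensor that the
    # prefill pass consumes directly.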
    def compile(self, *args, **kwds):
        # Experimental features to reduce compilation time; these are expected
        # to become the defaults in future PyTorch releases.
        torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.fx_graph_cache = True
        torch._inductor.config.triton.cudagraph_trees = True
        torch._inductor.config.triton.cudagraph_support_input_mutation = True
        self.h.compile(fullgraph=True, mode="reduce-overhead")

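    # capture below follows the canonical CUDA graph recipe: warm up the
    # kernels a few times on a side stream (so lazy initialization and
    # autotuning happen outside capture), then record one decode step into a
    # graph whose replay writes its result into the static x_dec buffer.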
    def capture(
        self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
    ) -> CUDAGraph:
        assert torch.cuda.is_available()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())

        graph = torch.cuda.CUDAGraph()

        with torch.cuda.stream(s):
            for _ in range(5):
                self.h(input_pos, x, kv_caches, *args, **kwds)
        torch.cuda.current_stream().wait_stream(s)

        with torch.cuda.graph(graph):
            x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
        torch.cuda.synchronize()

        return graph

    @abstractmethod
    def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
        return list(), dict()

    @abstractmethod
    def post_forward(self, idx: int, session: T2SSession) -> None:
        return


class CUDAGraphCacheABC(ABC):
    def __init__(
        self,
        decoder: T2SDecoderABC,
    ) -> None:
        self.is_applicable: bool

        if torch.cuda.is_available() and self.is_applicable:
            self.device: torch.device = decoder.device
            self.dtype = decoder.bert_proj.bias.dtype

            self.assigned: bool = False

            self.decoder: T2SDecoderABC = decoder
            self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
            self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
                self.dtype
            )
            self.xy_dec = self.xy_pos.clone()

            self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
            self.graph: torch.cuda.CUDAGraph | None = None
            self.stream: torch.cuda.Stream | None

        self.id: int = random.randint(1, 2**32 - 1)

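    # assign_graph captures the decoder once per cache object and then reuses
    # it: the first caller triggers capture, the first session is bound to the
    # cached graph (get_cache_graph) and tagged with the random `id`, and any
    # later session records a fresh graph via capture_new_graph.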
    def assign_graph(self, session: T2SSession):
        if self.graph is None:
            args, kwds = self.decoder.pre_forward(session)
            graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
            self.graph = graph
            self.stream = torch.cuda.Stream()

        if self.assigned is False:
            self.get_cache_graph(session)
            session.id = self.id
            self.assigned = True
        else:
            self.capture_new_graph(session)

    @abstractmethod
    def release_graph(self, session: T2SSession): ...

    @abstractmethod
    def get_cache_graph(self, session: T2SSession):
        pass

    @abstractmethod
    def capture_new_graph(self, session: T2SSession):
        pass


class TorchProfiler:
    def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
        self.debug = debug
        self.log_dir = log_dir
        self.__profiler: torch.profiler.profile

        if self.debug and not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)

    def profiler_callback(self, prof: torch.profiler.profile):
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
        self.tensorboard_handler(prof)

    @staticmethod
    def three_step_schedule(step: int) -> ProfilerAction:
        if step == 0:
            return ProfilerAction.NONE
        elif step == 1:
            return ProfilerAction.RECORD
        elif step == 2:
            return ProfilerAction.RECORD_AND_SAVE
        else:
            return ProfilerAction.NONE

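    # start() and end() each advance the profiler by one step, so together
    # with the schedule above a start()/end() pair brackets exactly one
    # recorded region: step 0 does nothing, step 1 records, and step 2
    # records and flushes the trace through profiler_callback.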
    def start(self):
        if not self.debug:
            return
        assert self.__profiler is not None
        self.__profiler.step()

    def end(self):
        if not self.debug:
            return
        assert self.__profiler is not None
        self.__profiler.step()

    def profiler(self):
        if self.debug:
            activities_list = [torch.profiler.ProfilerActivity.CPU]
            if torch.cuda.is_available():
                activities_list.append(torch.profiler.ProfilerActivity.CUDA)

            self.__profiler = torch.profiler.profile(
                activities=activities_list,
                record_shapes=True,
                with_stack=True,
                with_modules=True,
                profile_memory=True,
                schedule=self.three_step_schedule,
                on_trace_ready=self.profiler_callback,
            )
            return self.__profiler
        else:
            return nullcontext()

    def record(self, func_name: str):
        if self.debug:
            return torch.profiler.record_function(func_name)
        else:
            return nullcontext()