"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
import math
import os
import random
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import MutableSequence
import torch
import torch._inductor.config
import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler
from . import nn
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession

Tensor = torch.Tensor


class TokenEmbedding(nn.Module):
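    """Token-id to dense-vector lookup backed by ``nn.Embedding``."""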
def __init__(
self,
embedding_dim: int,
vocab_size: int,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self) -> Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> Tensor:
return self.word_embeddings.weight[index : index + 1]
def __call__(self, x: Tensor):
x = self.word_embeddings(x)
return x


class SinePositionalEmbedding(nn.Module):
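    """Sinusoidal positional embedding with a pre-computed ``pe`` buffer.

    ``__call__`` adds the encoding for a single decode step selected by
    ``input_pos``; ``prefill`` adds encodings for a whole prefix at once.
    An optional learnable ``alpha`` scales the positional term.
    """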
def __init__(
self,
embedding_dim: int,
scale: bool = False,
alpha: bool = False,
max_batch_size: int = 10,
max_seq_len: int = 1800,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.reverse = False
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
self.pe: torch.Tensor
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe = self.pe
pe[:, :, 0::2] = torch.sin(position * div_term)
pe[:, :, 1::2] = torch.cos(position * div_term)
def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
"""
Args:
input_pos (Tensor): [batch_size, ]
x (Tensor): [batch_size, 1, embed_dim]
Returns:
embedded_x (Tensor): [batch_size, 1, embed_dim]
"""
batch_size = x.shape[0]
pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
def prefill(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): [batch_size, seq_len, embed_dim]
Returns:
embedded_x (Tensor): [batch_size, seq_len, embed_dim]
"""
pe_values = self.pe[:, : x.shape[-2]]
return x * self.x_scale + self.alpha.item() * pe_values


class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
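    """Base class for per-layer KV caches with pre-allocated key/value buffers."""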
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim
self.batch_size = batch_size
self.max_seq_length = max_seq_length
self.k_cache: Tensor
self.v_cache: Tensor
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
@abstractmethod
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
@abstractmethod
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
def sync_cache(self, kv_cache: KVCacheProtocol):
self.k_cache.copy_(kv_cache.k_cache)
self.v_cache.copy_(kv_cache.v_cache)


class KVCacheNHD(KVCacheABC):
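    """KV cache laid out as (batch, seq, n_heads, head_dim).

    Single-step updates are written with ``scatter_`` along the sequence dimension.
    """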
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
assert batch_size > 0
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, 1, H, D]
index = (
(input_pos - 1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(
-1,
-1,
self.n_head,
self.head_dim,
)
.to(torch.int64)
) # (bs, 1, num_head, head_dim)
k_out = self.k_cache
v_out = self.v_cache
k_out.scatter_(1, index, k_val)
v_out.scatter_(1, index, v_val)
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # k_val, v_val: [B, S, H, D]
self.k_cache[:, : k_val.shape[1]] = k_val
self.v_cache[:, : v_val.shape[1]] = v_val


class KVCacheHND(KVCacheABC):
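    """KV cache laid out as (batch, n_heads, seq, head_dim).

    Single-step updates are written with ``scatter_`` along the sequence dimension.
    """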
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, H, 1, D]
index = (
(input_pos - 1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(
-1,
self.n_head,
-1,
self.head_dim,
)
.to(torch.int64)
) # (bs, num_head, 1, head_dim)
k_out = self.k_cache
v_out = self.v_cache
k_out.scatter_(2, index, k_val)
v_out.scatter_(2, index, v_val)
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # k_val, v_val: [B, S, H, D]
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)


class KVCacheHNDVarlen(KVCacheABC):
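    """Variant of the (batch, n_heads, seq, head_dim) cache that writes decode
    steps via advanced indexing with per-sample ``input_pos`` instead of ``scatter_``.
    """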
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.cache_idx: Tensor
self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, H, 1, D]
k_out = self.k_cache
v_out = self.v_cache
ip0 = input_pos - 1
k_out[self.cache_idx, :, ip0, None] = k_val
v_out[self.cache_idx, :, ip0, None] = v_val
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # k_val, v_val: [B, S, H, D]
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)


class AttentionABC(nn.Module, ABC):
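    """Multi-head attention base with a fused QKV projection (``in_proj``).

    Subclasses implement the single-step decode path; ``prefill`` runs dense
    scaled-dot-product attention over the prompt and fills the KV cache.
    """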
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__()
self.n_head = n_head
self.hidden_dim = hidden_dim
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.max_seq_length = max_seq_length
# key, query, value projections for all heads, but in a batch
self.in_proj: nn.Linear
self.out_proj: nn.Linear
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
for key in keys_to_modify:
new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
state_dict[new_key] = state_dict.pop(key)
@abstractmethod
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...
def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)
q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
kv_cache.prefill_kv(k, v)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
output = self.out_proj(attn)
return output


class FeedForward(nn.Module):
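    """Two-layer position-wise feed-forward network with ReLU activation."""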
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
def __call__(self, x: Tensor):
return self.linear2(F.relu(self.linear1(x), inplace=True))


class TransformerBlockABC(nn.Module, ABC):
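    """Transformer block with post-norm residual layout: LayerNorm is applied
    after each residual sum (attention, then feed-forward).
    """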
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.max_seq_length = max_seq_length
self.attention: AttentionABC
self.feed_forward: FeedForward
self.attention_norm: nn.LayerNorm
self.ffn_norm: nn.LayerNorm
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
for key in list(state_dict.keys()):
new_key = (
key.replace("self_attn", "attention")
.replace("linear", "feed_forward.linear")
.replace("norm1", "attention_norm")
.replace("norm2", "ffn_norm")
)
state_dict[new_key] = state_dict.pop(key)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
h = self.attention_norm(
x
+ self.attention(
x,
input_pos,
kv_cache,
*args,
**kwds,
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
def prefill(
self,
x: Tensor,
kv_cache: KVCacheProtocol,
attn_mask: Tensor,
) -> Tensor:
h = self.attention_norm(
x
+ self.attention.prefill(
x,
kv_cache,
attn_mask,
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out


class TransformerDecoderABC(nn.Module, ABC):
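    """Stack of transformer blocks; each layer consumes its own KV cache."""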
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.n_head = n_head
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layers: MutableSequence[TransformerBlockABC]
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer(x, input_pos, kv_cache, *args, **kwds)
return x
def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(x, kv_cache, attn_mask)
return x


class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
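    """T2S (text-to-semantic) decoder base.

    Owns the phoneme and audio token embeddings with their positional
    encodings, the transformer stack ``h``, and helpers for KV-cache
    allocation, ``torch.compile`` and CUDA-graph capture.
    """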
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__()
hidden_dim: int = config["model"]["hidden_dim"]
embedding_dim: int = config["model"]["embedding_dim"]
n_head: int = config["model"]["head"]
n_layer: int = config["model"]["n_layer"]
vocab_size: int = config["model"]["vocab_size"]
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
EOS: int = config["model"]["EOS"]
ffn_dim: int = hidden_dim * 4
self.n_layer = int(n_layer)
self.hidden_dim = int(hidden_dim)
self.n_head = int(n_head)
assert hidden_dim % n_head == 0
self.head_dim = int(hidden_dim // n_head)
self.embedding_dim = int(embedding_dim)
self.ffn_dim = int(ffn_dim)
self.vocab_size = int(vocab_size)
self.phoneme_vocab_size = int(phoneme_vocab_size)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
self.EOS = EOS
assert self.EOS == self.vocab_size - 1
self.bert_proj: nn.Linear
self.ar_predict_layer: nn.Linear
self.h: TransformerDecoderABC
self.kv_class: type[KVCacheABC]
self.GraphCache: CUDAGraphCacheABC | None
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
alpha=True,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
alpha=True,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
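        # Drop the leading "model." from checkpoint keys so they match this module's parameter names.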
model_keys = [key for key in state_dict if key.startswith("model.")]
for key in model_keys:
new_key = key[len("model.") :]
state_dict[new_key] = state_dict.pop(key)
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
kvclass = self.kv_class
return nn.ModuleList(
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
).to(self.device, dtype) # type: ignore
def embed(
self,
x: list[torch.Tensor],
y: torch.Tensor,
bert_features: list[torch.Tensor],
):
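        """Build the packed prompt embeddings for prefill.

        For every sample, the phoneme embedding (plus projected BERT features)
        is written at the start of its row, followed immediately by the
        audio-token embedding; rows are zero-padded up to the longest prompt.
        Returns a tensor of shape [batch, max_x_len + y_len, embedding_dim].
        """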
x_len: list[int] = [i.shape[0] for i in x]
x_len_max = max(x_len)
xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
y_len = y.shape[1]
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position.prefill(y_emb)
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
x_emb = self.ar_text_embedding(x_)
bert = self.bert_proj(bert_feature)
x_emb = x_emb + bert
x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
xy_pos[[bs], :len_] = x_pos
xy_pos[[bs], len_ : len_ + y_len] = y_pos
return xy_pos
def compile(self, *args, **kwds):
        # Experimental features to reduce compilation time; expected to be enabled by default in future PyTorch releases.
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.triton.cudagraph_trees = True
torch._inductor.config.triton.cudagraph_support_input_mutation = True
self.h.compile(fullgraph=True, mode="reduce-overhead")
def capture(
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
) -> CUDAGraph:
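        """Capture one decode step of ``self.h`` into a CUDA graph.

        Follows the usual capture recipe: warm the kernels up on a side stream,
        then record the same call with ``torch.cuda.graph`` so it can later be
        replayed against the static ``input_pos``/``x``/``x_dec`` buffers.
        """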
assert torch.cuda.is_available()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
graph = torch.cuda.CUDAGraph()
with torch.cuda.stream(s):
for _ in range(5):
self.h(input_pos, x, kv_caches, *args, **kwds)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(graph):
x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
torch.cuda.synchronize()
return graph
@abstractmethod
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
return list(), dict()
@abstractmethod
def post_forward(self, idx: int, session: T2SSession) -> None:
return


class CUDAGraphCacheABC(ABC):
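    """Holds a captured CUDA graph together with its static buffers
    (``xy_pos``, ``xy_dec``, ``input_pos``) and dedicated KV caches, so decode
    steps can be replayed instead of re-launched kernel by kernel.
    """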
def __init__(
self,
decoder: T2SDecoderABC,
) -> None:
self.is_applicable: bool
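        # Subclasses are expected to set `is_applicable` before calling super().__init__().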
if torch.cuda.is_available() and self.is_applicable:
self.device: torch.device = decoder.device
self.dtype = decoder.bert_proj.bias.dtype
self.assigned: bool = False
self.decoder: T2SDecoderABC = decoder
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
self.dtype
)
self.xy_dec = self.xy_pos.clone()
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
self.graph: torch.cuda.CUDAGraph | None = None
self.stream: torch.cuda.Stream | None
self.id: int = random.randint(1, 2**32 - 1)
def assign_graph(self, session: T2SSession):
if self.graph is None:
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
self.graph = graph
self.stream = torch.cuda.Stream()
if self.assigned is False:
self.get_cache_graph(session)
session.id = self.id
self.assigned = True
else:
self.capture_new_graph(session)
@abstractmethod
def release_graph(self, session: T2SSession): ...
@abstractmethod
def get_cache_graph(self, session: T2SSession):
pass
@abstractmethod
def capture_new_graph(self, session: T2SSession):
pass


class TorchProfiler:
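    """Thin wrapper around ``torch.profiler.profile``, active only when ``debug`` is set.

    ``profiler()`` builds the profiler (or returns a ``nullcontext`` when
    disabled); ``start()`` and ``end()`` advance ``three_step_schedule`` so the
    region between them is recorded, printed, and exported to TensorBoard via
    ``profiler_callback``.
    """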
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
self.debug = debug
self.log_dir = log_dir
self.__profiler: torch.profiler.profile
if self.debug and not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
def profiler_callback(self, prof: torch.profiler.profile):
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
self.tensorboard_handler(prof)
@staticmethod
def three_step_schedule(step: int) -> ProfilerAction:
if step == 0:
return ProfilerAction.NONE
elif step == 1:
return ProfilerAction.RECORD
elif step == 2:
return ProfilerAction.RECORD_AND_SAVE
else:
return ProfilerAction.NONE
def start(self):
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def end(self):
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def profiler(self):
if self.debug:
activities_list = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities_list.append(torch.profiler.ProfilerActivity.CUDA)
self.__profiler = torch.profiler.profile(
activities=activities_list,
record_shapes=True,
with_stack=True,
with_modules=True,
profile_memory=True,
schedule=self.three_step_schedule,
on_trace_ready=self.profiler_callback,
)
return self.__profiler
else:
return nullcontext()
def record(self, func_name: str):
if self.debug:
return torch.profiler.record_function(func_name)
else:
return nullcontext()