from typing import NoReturn

import torch
from torch.nn import functional as F

from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheHNDVarlen,
    T2SDecoderABC,
    TransformerBlockABC,
    TransformerDecoderABC,
)

Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head, hidden_dim, max_seq_length):
        super().__init__(n_head, hidden_dim, max_seq_length)

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k, v = kv_cache.update(input_pos, k, v)

        # Truncate to the longest populated cache position so attention skips unused slots.
        max_idx = input_pos.max()
        q, k, v = map(lambda x: x[..., :max_idx, :], (q, k, v))
        mask = attn_mask[..., :max_idx]

        attn = F.scaled_dot_product_attention(q, k, v, mask)

        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention = Attention(n_head, hidden_dim, max_seq_length)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        max_seq_length=2000,
        max_batch_size=10,
    ) -> None:
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )
        self.kv_class = KVCacheHNDVarlen

    def capture(
        self,
        *args,
        **kwds,
    ) -> NoReturn:
        raise NotImplementedError("CUDA Graph is not supported for the varlen model")

    def pre_forward(self, session: T2SSession):
        attn_mask = session.attn_mask
        return list(), dict(attn_mask=attn_mask)

    def post_forward(self, idx: int, session: T2SSession) -> None:
        if idx == 0:
            # First step: build a (bsz, n_head, 1, max_seq_length) boolean mask that
            # exposes each sequence's own prompt tokens and nothing else.
            prefill_len = session.prefill_len
            bsz = session.bsz

            range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
            prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
            attn_mask = range_tensor < prefill_len_expanded
            attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)

            session.attn_mask = attn_mask

        # Every step: unmask the KV slot that this step just wrote at input_pos.
        attn_mask = session.attn_mask
        input_pos = session.input_pos
        attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
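# --- Illustrative sketch (not part of the original module) ------------------
# Mirrors the mask logic in T2SDecoder.post_forward above with concrete toy
# shapes: step 0 exposes each sequence's own prompt, and every decode step
# then unmasks the slot it just wrote. The numbers are made up for
# illustration; .contiguous() is added only so the in-place update below
# owns its memory rather than writing through the expanded head dimension.
def _varlen_mask_demo() -> None:
    bsz, n_head, max_seq_length = 2, 4, 8
    prefill_len = torch.tensor([3, 5])  # per-sequence prompt lengths

    # Step 0: visibility limited to each sequence's own prompt tokens.
    range_tensor = torch.arange(max_seq_length).view(1, 1, 1, max_seq_length)
    attn_mask = (range_tensor < prefill_len.view(bsz, 1, 1, 1)).expand(-1, n_head, -1, -1).contiguous()

    # Two decode steps: unmask the freshly written KV slot at input_pos.
    input_pos = prefill_len.clone()
    for _ in range(2):
        attn_mask[torch.arange(bsz), :, :, input_pos] = True
        input_pos = input_pos + 1

    # Sequence 0 now sees positions 0..4, sequence 1 sees positions 0..6.
    print(attn_mask[:, 0, 0].int())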
class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder,
    ) -> None:
        # The varlen backend slices tensors to a data-dependent length, so
        # CUDA Graph capture does not apply; mark the cache accordingly.
        self.is_applicable = False
        super().__init__(decoder)

    def release_graph(self, session: T2SSession):
        raise NotImplementedError("CUDA Graph is not supported for the varlen model")

    def get_cache_graph(self, session: T2SSession):
        raise NotImplementedError("CUDA Graph is not supported for the varlen model")

    def capture_new_graph(self, session: T2SSession):
        raise NotImplementedError("CUDA Graph is not supported for the varlen model")
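
# --- Hypothetical sketch (assumption, not part of the original module) ------
# Attention.__call__ only relies on KVCacheProtocol through the call
# `k, v = kv_cache.update(input_pos, k, v)`. A minimal decode-step cache
# (seqlen == 1) compatible with that call might look like the toy below.
# The real implementation used here is KVCacheHNDVarlen; this class, its
# name, and its constructor are invented purely for illustration.
class _ToyVarlenKVCache:
    def __init__(self, bsz: int, n_head: int, max_seq_length: int, head_dim: int) -> None:
        self.k_cache = torch.zeros(bsz, n_head, max_seq_length, head_dim)
        self.v_cache = torch.zeros(bsz, n_head, max_seq_length, head_dim)

    def update(self, input_pos: Tensor, k: Tensor, v: Tensor):
        # input_pos: (bsz,) write position per sequence; k, v: (bsz, n_head, 1, head_dim).
        batch_idx = torch.arange(input_pos.shape[0])
        self.k_cache[batch_idx, :, input_pos] = k[:, :, 0]
        self.v_cache[batch_idx, :, input_pos] = v[:, :, 0]
        # Return the full buffers; the caller truncates them to input_pos.max().
        return self.k_cache, self.v_cache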