From 08d627c3338173c3229286d8787060d6559fe0f8 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Thu, 30 Apr 2026 15:01:45 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0cuda=20graph=E6=94=AF?= =?UTF-8?q?=E6=8C=81=EF=BC=8C=E6=99=AE=E9=80=9A=E6=8E=A8=E7=90=86=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E6=8E=A8=E7=90=86=E9=80=9F=E5=BA=A6=E5=8E=9F=E5=9C=B0?= =?UTF-8?q?=E7=BF=BB=E5=80=8D=EF=BC=8C=E6=95=88=E6=9E=9C=E4=B8=8D=E5=8F=98?= =?UTF-8?q?=E3=80=822?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。2 --- GPT_SoVITS/AR/models/embedding_cudagraph.py | 76 +++ GPT_SoVITS/AR/models/structs_cudagraph.py | 78 +++ GPT_SoVITS/AR/models/t2s_model_cudagraph.py | 602 ++++++++++++++++++++ 3 files changed, 756 insertions(+) create mode 100644 GPT_SoVITS/AR/models/embedding_cudagraph.py create mode 100644 GPT_SoVITS/AR/models/structs_cudagraph.py create mode 100644 GPT_SoVITS/AR/models/t2s_model_cudagraph.py diff --git a/GPT_SoVITS/AR/models/embedding_cudagraph.py b/GPT_SoVITS/AR/models/embedding_cudagraph.py new file mode 100644 index 00000000..b34f0824 --- /dev/null +++ b/GPT_SoVITS/AR/models/embedding_cudagraph.py @@ -0,0 +1,76 @@ +import math + +import torch +from torch import nn + + +class TokenEmbedding(nn.Module): + def __init__(self, embedding_dim: int, vocab_size: int, dropout: float = 0.0): + super().__init__() + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.dropout = nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + x = self.word_embeddings(x) + x = self.dropout(x) + return x + + +class SinePositionalEmbeddingNested(nn.Module): + def __init__( + self, + embedding_dim: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + max_batch_size: int = 20, + max_seq_len: int = 2500, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = nn.Dropout(p=dropout) + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + + self.reverse = False + self.register_buffer( + "pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False + ) + self.pe: torch.Tensor + self.compute_pe() + + def compute_pe(self): + if self.reverse: + position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + else: + position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.embedding_dim) + ) + pe = self.pe + pe[:, :, 0::2] = torch.sin(position * div_term) + pe[:, :, 1::2] = torch.cos(position * div_term) + + def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] + pe_values = self.pe[torch.arange(batch_size), input_pos - 1] + return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) + + def prefill(self, x: torch.Tensor) -> torch.Tensor: + input_pos = torch.tensor([i.shape[0] for i in x.unbind()]) + pe_values = torch.nested.nested_tensor( + [self.pe[i, : input_pos[i], :] for i in range(input_pos.size(0))] + ) + return x * self.x_scale + self.alpha.item() * pe_values diff --git a/GPT_SoVITS/AR/models/structs_cudagraph.py b/GPT_SoVITS/AR/models/structs_cudagraph.py new file mode 100644 index 00000000..52fd11c3 --- /dev/null +++ b/GPT_SoVITS/AR/models/structs_cudagraph.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Literal, Optional + +import torch + +Tensor = torch.Tensor + + +@dataclass +class T2SResult: + result: List[Tensor] | None = None + infer_speed: float = 0.0 + status: Literal["Success", "Error"] = "Success" + exception: Optional[Exception] = None + traceback: Optional[str] = None + + +@dataclass +class T2SRequest: + x: List[torch.Tensor] + x_lens: Tensor + prompts: torch.Tensor + bert_feature: List[Tensor] + valid_length: int + top_k: int = 5 + top_p: float = 1 + early_stop_num: int = -1 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + use_cuda_graph: bool = False + debug: bool = False + + +class T2SSession: + def __init__(self, decoder, request: T2SRequest, device: torch.device, dtype: torch.dtype): + with device: + self.decoder = decoder + self.request = request + self.device = device + self.dtype = dtype + + bsz = len(request.x) + y_len = request.prompts.size(-1) + self.bsz = bsz + self.y_len = y_len + + from AR.models.t2s_model_cudagraph import Sampler + + self.sampler = Sampler(bsz, decoder.vocab_size) + + self.x = request.x + self.x_lens = request.x_lens.to(torch.int32) + self.y = request.prompts + self.bert_feature = request.bert_feature + + self.prefill_len = self.x_lens + self.y.size(1) + + self.input_pos = torch.zeros_like(self.prefill_len) + self.input_pos.add_(self.prefill_len) + + self.completed = torch.Tensor([False] * len(self.x)).bool().to(device) + self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore + + self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature) + + attn_mask = [] + for bs in range(bsz): + pos = int(self.x_lens[bs].item()) + mask = torch.zeros(pos + y_len, pos + y_len).bool() + mask[:, :pos].fill_(True) + if y_len > 0: + mask[-y_len:, -y_len:] = ~torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1 + ) + attn_mask.append(mask) + self.attn_mask_nested = torch.nested.nested_tensor(attn_mask) diff --git a/GPT_SoVITS/AR/models/t2s_model_cudagraph.py b/GPT_SoVITS/AR/models/t2s_model_cudagraph.py new file mode 100644 index 00000000..0dcb1093 --- /dev/null +++ b/GPT_SoVITS/AR/models/t2s_model_cudagraph.py @@ -0,0 +1,602 @@ +""" +CUDA Graph accelerated T2S decoder. +Uses PyTorch native scaled_dot_product_attention (no flash_attn dependency). +Adapted from gsvpp/AR/models/t2s_model_abc.py and t2s_model_flash_attn.py. +""" + +from __future__ import annotations + +import os +import time +import traceback +from typing import Dict, List, MutableSequence, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.graphs import CUDAGraph +from tqdm import tqdm + +from AR.models.embedding_cudagraph import ( + SinePositionalEmbeddingNested as SinePositionalEmbedding, +) +from AR.models.embedding_cudagraph import TokenEmbedding +from AR.models.structs_cudagraph import T2SRequest, T2SResult, T2SSession + +Tensor = torch.Tensor + + +class Sampler(nn.Module): + def __init__(self, batch_size: int, vocab_size: int) -> None: + super().__init__() + self.batch_size = batch_size + + def sample( + self, + logits: Tensor, + previous_tokens: Tensor, + temperature: float, + top_k: int, + top_p: float, + repetition_penalty: float, + ) -> Tensor: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=1, index=previous_tokens, src=score) + + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum( + torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 + ) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[:, 0] = False + indices_to_remove = sorted_indices_to_remove.scatter( + dim=1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + v, _ = torch.topk(logits, top_k) + pivot = v[:, -1].unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + q = torch.empty_like(probs).exponential_(1.0) + idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32) + + return idx_next + + +# ─── KV Cache ────────────────────���─────────────────────────────────────────── + + +class KVCacheNHD(nn.Module): + def __init__(self, batch_size, max_seq_length, n_heads, head_dim): + super().__init__() + assert batch_size > 0 + cache_shape = (batch_size, max_seq_length, n_heads, head_dim) + self.n_head = n_heads + self.head_dim = head_dim + self.batch_size = batch_size + self.max_seq_length = max_seq_length + self.register_buffer( + "k_cache", torch.zeros(size=cache_shape), persistent=False + ) + self.register_buffer( + "v_cache", torch.zeros(size=cache_shape), persistent=False + ) + + def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor): + index = ( + (input_pos - 1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + .expand(-1, -1, self.n_head, self.head_dim) + .to(torch.int64) + ) + k_out = self.k_cache + v_out = self.v_cache + k_out.scatter_(1, index, k_val) + v_out.scatter_(1, index, v_val) + return k_out, v_out + + def empty(self): + self.k_cache.zero_() + self.v_cache.zero_() + + def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int): + self.k_cache[[bs], : k_val.shape[1]] = k_val + self.v_cache[[bs], : v_val.shape[1]] = v_val + + +# ─── Attention (PyTorch native SDPA, no flash_attn) ───────────────────────── + + +class Attention(nn.Module): + def __init__(self, n_head: int, hidden_dim: int): + super().__init__() + self.n_head = n_head + self.hidden_dim = hidden_dim + assert hidden_dim % n_head == 0 + self.head_dim = hidden_dim // n_head + self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True) + self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True) + self.dropout = nn.Dropout(0.1) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict: dict, prefix, *args): + keys_to_modify = [key for key in state_dict if "in_proj_" in key] + for key in keys_to_modify: + new_key = key.replace("in_proj_", "in_proj.") + state_dict[new_key] = state_dict.pop(key) + + def forward( + self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheNHD + ) -> Tensor: + bsz, seqlen, _ = x.shape + + q, k, v = self.in_proj.forward(x).chunk(3, dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_head, self.head_dim) + v = v.view(bsz, seqlen, self.n_head, self.head_dim) + + k_cache, v_cache = kv_cache.update(input_pos, k, v) + + q = q.transpose(1, 2) # [B, H, 1, D] + k_out = k_cache.transpose(1, 2) # [B, H, max_seq, D] + v_out = v_cache.transpose(1, 2) # [B, H, max_seq, D] + + attn = F.scaled_dot_product_attention(q, k_out, v_out) + + attn = self.dropout.forward(attn) + attn = attn.transpose(1, 2).reshape(bsz, seqlen, self.hidden_dim) + attn = self.out_proj.forward(attn) + return attn + + def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheNHD) -> Tensor: + bsz = x.size(0) + outputs = [] + for bs in range(bsz): + x_b = x[bs].unsqueeze(0) + q, k, v = self.in_proj.forward(x_b.unsqueeze(0)).chunk(3, dim=-1) + q = q.contiguous().view(1, -1, self.n_head, self.head_dim) + k = k.contiguous().view(1, -1, self.n_head, self.head_dim) + v = v.contiguous().view(1, -1, self.n_head, self.head_dim) + kv_cache.prefill_kv(k, v, bs) + q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) + attn_mask = ( + mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1) + ) + attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + attn = self.dropout.forward(attn) + attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim) + output = self.out_proj.forward(attn) + outputs.append(output.squeeze(0)) + return torch.nested.nested_tensor(outputs) + + +# ─── Feed Forward ──────────────────────────────────────────────────────────── + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int) -> None: + super().__init__() + self.linear1 = nn.Linear(dim, hidden_dim, bias=True) + self.linear2 = nn.Linear(hidden_dim, dim, bias=True) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: Tensor) -> Tensor: + return self.dropout.forward( + self.linear2(self.dropout.forward(F.relu(self.linear1(x)))) + ) + + +# ─── Transformer Block ────────────────────────────────────────────────────── + + +class TransformerBlock(nn.Module): + def __init__(self, n_head, ffn_dim, hidden_dim) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.attention = Attention(n_head, hidden_dim) + self.feed_forward = FeedForward(hidden_dim, ffn_dim) + self.attention_norm = nn.LayerNorm([hidden_dim]) + self.ffn_norm = nn.LayerNorm([hidden_dim]) + self.dropout = nn.Dropout(0.1) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict: dict[str, Tensor], prefix, *args): + for key in list(state_dict.keys()): + new_key = ( + key.replace("self_attn", "attention") + .replace("linear", "feed_forward.linear") + .replace("norm1", "attention_norm") + .replace("norm2", "ffn_norm") + ) + state_dict[new_key] = state_dict.pop(key) + + def forward( + self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheNHD + ) -> Tensor: + h = self.attention_norm.forward( + x + self.dropout.forward(self.attention.forward(x, input_pos, kv_cache)) + ) + out = self.ffn_norm.forward(h + self.feed_forward.forward(h)) + return out + + def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheNHD) -> Tensor: + h = self.attention_norm.forward( + x + self.dropout.forward(self.attention.prefill(x, mask, kv_cache)) + ) + out = self.ffn_norm.forward(h + self.feed_forward.forward(h)) + return out + + +# ─── Transformer Decoder ──────────────────────────────────────────────────── + + +class TransformerDecoder(nn.Module): + def __init__( + self, + hidden_dim, + n_layer, + n_head, + ffn_dim, + vocab_size, + max_seq_length, + max_batch_size, + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.n_head = n_head + assert hidden_dim % n_head == 0 + self.head_dim = hidden_dim // n_head + self.vocab_size = vocab_size + self.n_layer = n_layer + self.layers = nn.ModuleList( + TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer) + ) + self.max_seq_length: int = max_seq_length + self.max_batch_size: int = max_batch_size + + def forward( + self, + input_pos: Tensor, + x: Tensor, + kv_caches: MutableSequence[KVCacheNHD], + ): + for layer, kv_cache in zip(self.layers, kv_caches): + x = layer.forward(x, input_pos, kv_cache) + return x + + def prefill( + self, + x: Tensor, + mask: Tensor, + kv_caches: MutableSequence[KVCacheNHD], + ): + for layer, kv_cache in zip(self.layers, kv_caches): + x = layer.prefill(x, mask, kv_cache) + return x + + +# ─── T2S Decoder ───────────────────────────────────────────────────────────── + + +class T2SDecoder(nn.Module): + def __init__( + self, + config, + *args, + norm_first=False, + max_seq_length=2500, + max_batch_size=10, + **kwds, + ) -> None: + super().__init__() + hidden_dim = config["model"]["hidden_dim"] + embedding_dim = config["model"]["embedding_dim"] + n_head = config["model"]["head"] + n_layer = config["model"]["n_layer"] + vocab_size = config["model"]["vocab_size"] + phoneme_vocab_size = config["model"]["phoneme_vocab_size"] + p_dropout = config["model"]["dropout"] + EOS = config["model"]["EOS"] + ffn_dim = hidden_dim * 4 + + self.n_layer = n_layer + self.hidden_dim = hidden_dim + self.n_head = n_head + assert hidden_dim % n_head == 0 + self.head_dim = hidden_dim // n_head + self.embedding_dim = embedding_dim + self.vocab_size = vocab_size + self.phoneme_vocab_size = phoneme_vocab_size + self.p_dropout = p_dropout + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + self.EOS = EOS + assert self.EOS == self.vocab_size - 1 + + self.bert_proj = nn.Linear(1024, self.embedding_dim) + self.ar_text_embedding = TokenEmbedding( + self.embedding_dim, self.phoneme_vocab_size, self.p_dropout + ) + self.ar_text_position = SinePositionalEmbedding( + self.embedding_dim, + dropout=0.1, + scale=False, + alpha=True, + max_batch_size=max_batch_size, + max_seq_len=max_seq_length, + ) + self.ar_audio_embedding = TokenEmbedding( + self.embedding_dim, self.vocab_size, self.p_dropout + ) + self.ar_audio_position = SinePositionalEmbedding( + self.embedding_dim, + dropout=0.1, + scale=False, + alpha=True, + max_batch_size=max_batch_size, + max_seq_len=max_seq_length, + ) + self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False) + self.h = TransformerDecoder( + hidden_dim, + n_layer, + n_head, + ffn_dim, + vocab_size, + max_seq_length, + max_batch_size, + ) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + model_keys = [key for key in state_dict if key.startswith("model.")] + for key in model_keys: + new_key = key[len("model.") :] + state_dict[new_key] = state_dict.pop(key) + + def init_cache(self, bsz: int = 0) -> nn.ModuleList: + bsz = bsz or self.h.max_batch_size + assert bsz <= self.h.max_batch_size + seq_lens = self.h.max_seq_length + device = self.bert_proj.bias.device + dtype = self.bert_proj.bias.dtype + return nn.ModuleList( + [ + KVCacheNHD(bsz, seq_lens, self.n_head, self.head_dim) + for _ in range(self.n_layer) + ], + ).to(device, dtype) + + def embed( + self, + x: List[torch.Tensor], + y: torch.Tensor, + bert_features: List[torch.Tensor], + ): + x_nested = torch.nested.nested_tensor(x) + assert x_nested.size(0) <= self.max_batch_size + bert_features_nested = torch.nested.nested_tensor( + list(map(lambda t: t.transpose(0, 1), bert_features)) + ) + x_emb = self.ar_text_embedding.forward(x_nested) + bert = self.bert_proj.forward(bert_features_nested) + x_emb = x_emb + bert + x_pos = self.ar_text_position.prefill(x_emb) + + y_nested = torch.nested.nested_tensor(list(y.unbind(0))) + y_emb = self.ar_audio_embedding.forward(y_nested) + y_pos = self.ar_audio_position.prefill(y_emb) + + xy_pos = torch.nested.nested_tensor( + [torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))] + ) + return xy_pos + + def capture( + self, + input_pos: Tensor, + x: Tensor, + x_dec: Tensor, + kv_caches, + ) -> CUDAGraph: + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + + graph = torch.cuda.CUDAGraph() + + with torch.cuda.stream(s): + for _ in range(5): + self.h.forward(input_pos, x, kv_caches) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(graph): + x_dec.copy_(self.h.forward(input_pos, x, kv_caches)) + torch.cuda.synchronize() + + return graph + + +# ─── CUDA Graph Runner ─────────────────────────────────────────────────────── + + +class CUDAGraphRunner: + def __init__( + self, + decoder_model: T2SDecoder, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> None: + assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"} + assert dtype in {torch.float16, torch.bfloat16, torch.float32} + self.device = device + self.dtype = dtype + self.decoder_model: T2SDecoder = decoder_model.to(self.device, self.dtype) + self.graph: Optional[CUDAGraph] = None + self.xy_pos_ = torch.rand( + (1, 1, decoder_model.embedding_dim), device=device + ).to(dtype) + self.xy_dec_ = torch.rand( + (1, 1, decoder_model.embedding_dim), device=device + ).to(dtype) + self.kv_cache = decoder_model.init_cache(1) + self.input_pos = torch.tensor([10]).int().cuda() + + def _handle_request(self, request: T2SRequest): + with self.device: + for i in self.kv_cache: + i.empty() + + decoder = self.decoder_model + session = T2SSession(decoder, request, device=self.device, dtype=self.dtype) + self.input_pos.copy_(session.input_pos) + + t1 = 0.0 + infer_speed = 0.0 + y = session.y + bsz = y.size(0) + + for idx in tqdm(range(1500)): + if idx == 0: + xy_dec = decoder.h.prefill( + session.xy_pos, session.attn_mask_nested, self.kv_cache + ) + xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()]) + else: + if ( + request.use_cuda_graph + and self.graph is None + and torch.cuda.is_available() + ): + self.xy_pos_.copy_(session.xy_pos) + self.graph = decoder.capture( + self.input_pos, + self.xy_pos_, + self.xy_dec_, + kv_caches=self.kv_cache, + ) + + if self.graph: + self.xy_pos_.copy_(session.xy_pos) + self.graph.replay() + xy_dec = self.xy_dec_.clone() + else: + xy_dec = decoder.h.forward( + self.input_pos, + session.xy_pos, + self.kv_cache, + ) + + logits = decoder.ar_predict_layer(xy_dec[:, -1]) + self.input_pos.add_(1) + + if idx == 0: + logits[:, -1] = float("-inf") + + samples = session.sampler.sample( + logits=logits, + previous_tokens=session.y, + top_k=request.top_k, + top_p=request.top_p, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + ) + + session.y = torch.cat([session.y, samples], dim=1) + + argmax_token = torch.argmax(logits, dim=-1) + sample_token = samples.squeeze(1) + EOS_mask = (argmax_token == decoder.EOS) | ( + sample_token == decoder.EOS + ) + + newly_done_mask = EOS_mask & (~session.completed) + newly_done_indices = newly_done_mask.nonzero() + + if newly_done_indices.numel() > 0: + session.y_results[newly_done_indices[0]] = session.y[ + newly_done_indices[0], session.y_len : -1 + ].squeeze(0) + session.completed[newly_done_indices] = True + + if torch.all(session.completed).item(): + if session.y.size(1) == 0: + session.y = torch.cat( + [session.y, torch.zeros_like(samples)], dim=1 + ) + tqdm.write("Bad Zero Prediction") + else: + tqdm.write( + f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n" + f"{[i.size(0) for i in session.y_results].__str__().strip('[]')}" + ) + tqdm.write( + f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s" + ) + infer_speed = (idx - 1) / (time.perf_counter() - t1) + break + + if ( + request.early_stop_num != -1 + and (session.y.size(1) - session.y_len) > request.early_stop_num + ) or idx == 1499: + for i in range(bsz): + if not session.completed[i].item(): + session.y_results[i] = session.y[i, session.y_len :] + session.completed[i] = True + break + + y_emb = decoder.ar_audio_embedding(session.y[:, -1:]) + session.xy_pos = decoder.ar_audio_position.forward( + self.input_pos - session.x_lens, y_emb + ) + + if idx == 2: + t1 = time.perf_counter() + + if idx % 100 == 0 and self.device.type == "cuda": + torch.cuda.empty_cache() + + if self.device.type == "cuda": + torch.cuda.empty_cache() + + return session.y_results[: request.valid_length], infer_speed + + def generate(self, request: T2SRequest) -> T2SResult: + try: + result, infer_speed = self._handle_request(request) + t2s_result = T2SResult( + result=result, infer_speed=infer_speed, status="Success" + ) + except Exception as e: + t2s_result = T2SResult( + status="Error", exception=e, traceback=traceback.format_exc() + ) + return t2s_result + + @staticmethod + def load_decoder(weights_path, max_batch_size=1) -> T2SDecoder: + print( + f"Loading Text2Semantic Weights from {weights_path} with CUDA Graph (SDPA) Implement" + ) + dict_s1 = torch.load( + weights_path, map_location="cpu", weights_only=False#, mmap=True + ) + config = dict_s1["config"] + decoder = T2SDecoder(config, max_batch_size=max_batch_size) + state_dict = dict_s1["weight"] + decoder.load_state_dict(state_dict) + return decoder.eval()