GPT-SoVITS/GPT_SoVITS/AR/models/structs_cudagraph.py
RVC-Boss 08d627c333
增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。2
增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。2
2026-04-30 15:01:45 +08:00

79 lines
2.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import List, Literal, Optional
import torch
Tensor = torch.Tensor
@dataclass
class T2SResult:
result: List[Tensor] | None = None
infer_speed: float = 0.0
status: Literal["Success", "Error"] = "Success"
exception: Optional[Exception] = None
traceback: Optional[str] = None
@dataclass
class T2SRequest:
x: List[torch.Tensor]
x_lens: Tensor
prompts: torch.Tensor
bert_feature: List[Tensor]
valid_length: int
top_k: int = 5
top_p: float = 1
early_stop_num: int = -1
temperature: float = 1.0
repetition_penalty: float = 1.35
use_cuda_graph: bool = False
debug: bool = False
class T2SSession:
def __init__(self, decoder, request: T2SRequest, device: torch.device, dtype: torch.dtype):
with device:
self.decoder = decoder
self.request = request
self.device = device
self.dtype = dtype
bsz = len(request.x)
y_len = request.prompts.size(-1)
self.bsz = bsz
self.y_len = y_len
from AR.models.t2s_model_cudagraph import Sampler
self.sampler = Sampler(bsz, decoder.vocab_size)
self.x = request.x
self.x_lens = request.x_lens.to(torch.int32)
self.y = request.prompts
self.bert_feature = request.bert_feature
self.prefill_len = self.x_lens + self.y.size(1)
self.input_pos = torch.zeros_like(self.prefill_len)
self.input_pos.add_(self.prefill_len)
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
attn_mask = []
for bs in range(bsz):
pos = int(self.x_lens[bs].item())
mask = torch.zeros(pos + y_len, pos + y_len).bool()
mask[:, :pos].fill_(True)
if y_len > 0:
mask[-y_len:, -y_len:] = ~torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
)
attn_mask.append(mask)
self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)