mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-05 13:58:18 +08:00
603 lines
22 KiB
Python
603 lines
22 KiB
Python
"""
|
||
CUDA Graph accelerated T2S decoder.
|
||
Uses PyTorch native scaled_dot_product_attention (no flash_attn dependency).
|
||
Adapted from gsvpp/AR/models/t2s_model_abc.py and t2s_model_flash_attn.py.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import time
|
||
import traceback
|
||
from typing import Dict, List, MutableSequence, Optional, Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch.cuda.graphs import CUDAGraph
|
||
from tqdm import tqdm
|
||
|
||
from AR.models.embedding_cudagraph import (
|
||
SinePositionalEmbeddingNested as SinePositionalEmbedding,
|
||
)
|
||
from AR.models.embedding_cudagraph import TokenEmbedding
|
||
from AR.models.structs_cudagraph import T2SRequest, T2SResult, T2SSession
|
||
|
||
Tensor = torch.Tensor
|
||
|
||
|
||
class Sampler(nn.Module):
|
||
def __init__(self, batch_size: int, vocab_size: int) -> None:
|
||
super().__init__()
|
||
self.batch_size = batch_size
|
||
|
||
def sample(
|
||
self,
|
||
logits: Tensor,
|
||
previous_tokens: Tensor,
|
||
temperature: float,
|
||
top_k: int,
|
||
top_p: float,
|
||
repetition_penalty: float,
|
||
) -> Tensor:
|
||
previous_tokens = previous_tokens.long()
|
||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||
score = torch.where(
|
||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||
)
|
||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||
|
||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||
cum_probs = torch.cumsum(
|
||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||
)
|
||
sorted_indices_to_remove = cum_probs > top_p
|
||
sorted_indices_to_remove[:, 0] = False
|
||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||
)
|
||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||
|
||
logits = logits / max(temperature, 1e-5)
|
||
|
||
v, _ = torch.topk(logits, top_k)
|
||
pivot = v[:, -1].unsqueeze(-1)
|
||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||
|
||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||
q = torch.empty_like(probs).exponential_(1.0)
|
||
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||
|
||
return idx_next
|
||
|
||
|
||
# ─── KV Cache ────────────────────<E29480><E29480><EFBFBD>───────────────────────────────────────────
|
||
|
||
|
||
class KVCacheNHD(nn.Module):
|
||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||
super().__init__()
|
||
assert batch_size > 0
|
||
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
|
||
self.n_head = n_heads
|
||
self.head_dim = head_dim
|
||
self.batch_size = batch_size
|
||
self.max_seq_length = max_seq_length
|
||
self.register_buffer(
|
||
"k_cache", torch.zeros(size=cache_shape), persistent=False
|
||
)
|
||
self.register_buffer(
|
||
"v_cache", torch.zeros(size=cache_shape), persistent=False
|
||
)
|
||
|
||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||
index = (
|
||
(input_pos - 1)
|
||
.unsqueeze(-1)
|
||
.unsqueeze(-1)
|
||
.unsqueeze(-1)
|
||
.expand(-1, -1, self.n_head, self.head_dim)
|
||
.to(torch.int64)
|
||
)
|
||
k_out = self.k_cache
|
||
v_out = self.v_cache
|
||
k_out.scatter_(1, index, k_val)
|
||
v_out.scatter_(1, index, v_val)
|
||
return k_out, v_out
|
||
|
||
def empty(self):
|
||
self.k_cache.zero_()
|
||
self.v_cache.zero_()
|
||
|
||
def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
|
||
self.k_cache[[bs], : k_val.shape[1]] = k_val
|
||
self.v_cache[[bs], : v_val.shape[1]] = v_val
|
||
|
||
|
||
# ─── Attention (PyTorch native SDPA, no flash_attn) ─────────────────────────
|
||
|
||
|
||
class Attention(nn.Module):
|
||
def __init__(self, n_head: int, hidden_dim: int):
|
||
super().__init__()
|
||
self.n_head = n_head
|
||
self.hidden_dim = hidden_dim
|
||
assert hidden_dim % n_head == 0
|
||
self.head_dim = hidden_dim // n_head
|
||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||
self.dropout = nn.Dropout(0.1)
|
||
|
||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||
|
||
def load_hook(self, state_dict: dict, prefix, *args):
|
||
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
|
||
for key in keys_to_modify:
|
||
new_key = key.replace("in_proj_", "in_proj.")
|
||
state_dict[new_key] = state_dict.pop(key)
|
||
|
||
def forward(
|
||
self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheNHD
|
||
) -> Tensor:
|
||
bsz, seqlen, _ = x.shape
|
||
|
||
q, k, v = self.in_proj.forward(x).chunk(3, dim=-1)
|
||
|
||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||
|
||
k_cache, v_cache = kv_cache.update(input_pos, k, v)
|
||
|
||
q = q.transpose(1, 2) # [B, H, 1, D]
|
||
k_out = k_cache.transpose(1, 2) # [B, H, max_seq, D]
|
||
v_out = v_cache.transpose(1, 2) # [B, H, max_seq, D]
|
||
|
||
attn = F.scaled_dot_product_attention(q, k_out, v_out)
|
||
|
||
attn = self.dropout.forward(attn)
|
||
attn = attn.transpose(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||
attn = self.out_proj.forward(attn)
|
||
return attn
|
||
|
||
def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheNHD) -> Tensor:
|
||
bsz = x.size(0)
|
||
outputs = []
|
||
for bs in range(bsz):
|
||
x_b = x[bs].unsqueeze(0)
|
||
q, k, v = self.in_proj.forward(x_b.unsqueeze(0)).chunk(3, dim=-1)
|
||
q = q.contiguous().view(1, -1, self.n_head, self.head_dim)
|
||
k = k.contiguous().view(1, -1, self.n_head, self.head_dim)
|
||
v = v.contiguous().view(1, -1, self.n_head, self.head_dim)
|
||
kv_cache.prefill_kv(k, v, bs)
|
||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
||
attn_mask = (
|
||
mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1)
|
||
)
|
||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||
attn = self.dropout.forward(attn)
|
||
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
|
||
output = self.out_proj.forward(attn)
|
||
outputs.append(output.squeeze(0))
|
||
return torch.nested.nested_tensor(outputs)
|
||
|
||
|
||
# ─── Feed Forward ────────────────────────────────────────────────────────────
|
||
|
||
|
||
class FeedForward(nn.Module):
|
||
def __init__(self, dim: int, hidden_dim: int) -> None:
|
||
super().__init__()
|
||
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
||
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
||
self.dropout = nn.Dropout(0.1)
|
||
|
||
def forward(self, x: Tensor) -> Tensor:
|
||
return self.dropout.forward(
|
||
self.linear2(self.dropout.forward(F.relu(self.linear1(x))))
|
||
)
|
||
|
||
|
||
# ─── Transformer Block ──────────────────────────────────────────────────────
|
||
|
||
|
||
class TransformerBlock(nn.Module):
|
||
def __init__(self, n_head, ffn_dim, hidden_dim) -> None:
|
||
super().__init__()
|
||
self.hidden_dim = hidden_dim
|
||
self.attention = Attention(n_head, hidden_dim)
|
||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||
self.attention_norm = nn.LayerNorm([hidden_dim])
|
||
self.ffn_norm = nn.LayerNorm([hidden_dim])
|
||
self.dropout = nn.Dropout(0.1)
|
||
|
||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||
|
||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||
for key in list(state_dict.keys()):
|
||
new_key = (
|
||
key.replace("self_attn", "attention")
|
||
.replace("linear", "feed_forward.linear")
|
||
.replace("norm1", "attention_norm")
|
||
.replace("norm2", "ffn_norm")
|
||
)
|
||
state_dict[new_key] = state_dict.pop(key)
|
||
|
||
def forward(
|
||
self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheNHD
|
||
) -> Tensor:
|
||
h = self.attention_norm.forward(
|
||
x + self.dropout.forward(self.attention.forward(x, input_pos, kv_cache))
|
||
)
|
||
out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
|
||
return out
|
||
|
||
def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheNHD) -> Tensor:
|
||
h = self.attention_norm.forward(
|
||
x + self.dropout.forward(self.attention.prefill(x, mask, kv_cache))
|
||
)
|
||
out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
|
||
return out
|
||
|
||
|
||
# ─── Transformer Decoder ────────────────────────────────────────────────────
|
||
|
||
|
||
class TransformerDecoder(nn.Module):
|
||
def __init__(
|
||
self,
|
||
hidden_dim,
|
||
n_layer,
|
||
n_head,
|
||
ffn_dim,
|
||
vocab_size,
|
||
max_seq_length,
|
||
max_batch_size,
|
||
) -> None:
|
||
super().__init__()
|
||
self.hidden_dim = hidden_dim
|
||
self.n_head = n_head
|
||
assert hidden_dim % n_head == 0
|
||
self.head_dim = hidden_dim // n_head
|
||
self.vocab_size = vocab_size
|
||
self.n_layer = n_layer
|
||
self.layers = nn.ModuleList(
|
||
TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer)
|
||
)
|
||
self.max_seq_length: int = max_seq_length
|
||
self.max_batch_size: int = max_batch_size
|
||
|
||
def forward(
|
||
self,
|
||
input_pos: Tensor,
|
||
x: Tensor,
|
||
kv_caches: MutableSequence[KVCacheNHD],
|
||
):
|
||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||
x = layer.forward(x, input_pos, kv_cache)
|
||
return x
|
||
|
||
def prefill(
|
||
self,
|
||
x: Tensor,
|
||
mask: Tensor,
|
||
kv_caches: MutableSequence[KVCacheNHD],
|
||
):
|
||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||
x = layer.prefill(x, mask, kv_cache)
|
||
return x
|
||
|
||
|
||
# ─── T2S Decoder ─────────────────────────────────────────────────────────────
|
||
|
||
|
||
class T2SDecoder(nn.Module):
|
||
def __init__(
|
||
self,
|
||
config,
|
||
*args,
|
||
norm_first=False,
|
||
max_seq_length=2500,
|
||
max_batch_size=10,
|
||
**kwds,
|
||
) -> None:
|
||
super().__init__()
|
||
hidden_dim = config["model"]["hidden_dim"]
|
||
embedding_dim = config["model"]["embedding_dim"]
|
||
n_head = config["model"]["head"]
|
||
n_layer = config["model"]["n_layer"]
|
||
vocab_size = config["model"]["vocab_size"]
|
||
phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
|
||
p_dropout = config["model"]["dropout"]
|
||
EOS = config["model"]["EOS"]
|
||
ffn_dim = hidden_dim * 4
|
||
|
||
self.n_layer = n_layer
|
||
self.hidden_dim = hidden_dim
|
||
self.n_head = n_head
|
||
assert hidden_dim % n_head == 0
|
||
self.head_dim = hidden_dim // n_head
|
||
self.embedding_dim = embedding_dim
|
||
self.vocab_size = vocab_size
|
||
self.phoneme_vocab_size = phoneme_vocab_size
|
||
self.p_dropout = p_dropout
|
||
self.max_seq_length = max_seq_length
|
||
self.max_batch_size = max_batch_size
|
||
self.EOS = EOS
|
||
assert self.EOS == self.vocab_size - 1
|
||
|
||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||
self.ar_text_embedding = TokenEmbedding(
|
||
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
|
||
)
|
||
self.ar_text_position = SinePositionalEmbedding(
|
||
self.embedding_dim,
|
||
dropout=0.1,
|
||
scale=False,
|
||
alpha=True,
|
||
max_batch_size=max_batch_size,
|
||
max_seq_len=max_seq_length,
|
||
)
|
||
self.ar_audio_embedding = TokenEmbedding(
|
||
self.embedding_dim, self.vocab_size, self.p_dropout
|
||
)
|
||
self.ar_audio_position = SinePositionalEmbedding(
|
||
self.embedding_dim,
|
||
dropout=0.1,
|
||
scale=False,
|
||
alpha=True,
|
||
max_batch_size=max_batch_size,
|
||
max_seq_len=max_seq_length,
|
||
)
|
||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||
self.h = TransformerDecoder(
|
||
hidden_dim,
|
||
n_layer,
|
||
n_head,
|
||
ffn_dim,
|
||
vocab_size,
|
||
max_seq_length,
|
||
max_batch_size,
|
||
)
|
||
|
||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||
|
||
def load_hook(self, state_dict, prefix, *args):
|
||
model_keys = [key for key in state_dict if key.startswith("model.")]
|
||
for key in model_keys:
|
||
new_key = key[len("model.") :]
|
||
state_dict[new_key] = state_dict.pop(key)
|
||
|
||
def init_cache(self, bsz: int = 0) -> nn.ModuleList:
|
||
bsz = bsz or self.h.max_batch_size
|
||
assert bsz <= self.h.max_batch_size
|
||
seq_lens = self.h.max_seq_length
|
||
device = self.bert_proj.bias.device
|
||
dtype = self.bert_proj.bias.dtype
|
||
return nn.ModuleList(
|
||
[
|
||
KVCacheNHD(bsz, seq_lens, self.n_head, self.head_dim)
|
||
for _ in range(self.n_layer)
|
||
],
|
||
).to(device, dtype)
|
||
|
||
def embed(
|
||
self,
|
||
x: List[torch.Tensor],
|
||
y: torch.Tensor,
|
||
bert_features: List[torch.Tensor],
|
||
):
|
||
x_nested = torch.nested.nested_tensor(x)
|
||
assert x_nested.size(0) <= self.max_batch_size
|
||
bert_features_nested = torch.nested.nested_tensor(
|
||
list(map(lambda t: t.transpose(0, 1), bert_features))
|
||
)
|
||
x_emb = self.ar_text_embedding.forward(x_nested)
|
||
bert = self.bert_proj.forward(bert_features_nested)
|
||
x_emb = x_emb + bert
|
||
x_pos = self.ar_text_position.prefill(x_emb)
|
||
|
||
y_nested = torch.nested.nested_tensor(list(y.unbind(0)))
|
||
y_emb = self.ar_audio_embedding.forward(y_nested)
|
||
y_pos = self.ar_audio_position.prefill(y_emb)
|
||
|
||
xy_pos = torch.nested.nested_tensor(
|
||
[torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))]
|
||
)
|
||
return xy_pos
|
||
|
||
def capture(
|
||
self,
|
||
input_pos: Tensor,
|
||
x: Tensor,
|
||
x_dec: Tensor,
|
||
kv_caches,
|
||
) -> CUDAGraph:
|
||
s = torch.cuda.Stream()
|
||
s.wait_stream(torch.cuda.current_stream())
|
||
|
||
graph = torch.cuda.CUDAGraph()
|
||
|
||
with torch.cuda.stream(s):
|
||
for _ in range(5):
|
||
self.h.forward(input_pos, x, kv_caches)
|
||
torch.cuda.current_stream().wait_stream(s)
|
||
|
||
with torch.cuda.graph(graph):
|
||
x_dec.copy_(self.h.forward(input_pos, x, kv_caches))
|
||
torch.cuda.synchronize()
|
||
|
||
return graph
|
||
|
||
|
||
# ─── CUDA Graph Runner ───────────────────────────────────────────────────────
|
||
|
||
|
||
class CUDAGraphRunner:
|
||
def __init__(
|
||
self,
|
||
decoder_model: T2SDecoder,
|
||
device: torch.device = torch.device("cpu"),
|
||
dtype: torch.dtype = torch.float32,
|
||
) -> None:
|
||
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
|
||
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
|
||
self.device = device
|
||
self.dtype = dtype
|
||
self.decoder_model: T2SDecoder = decoder_model.to(self.device, self.dtype)
|
||
self.graph: Optional[CUDAGraph] = None
|
||
self.xy_pos_ = torch.rand(
|
||
(1, 1, decoder_model.embedding_dim), device=device
|
||
).to(dtype)
|
||
self.xy_dec_ = torch.rand(
|
||
(1, 1, decoder_model.embedding_dim), device=device
|
||
).to(dtype)
|
||
self.kv_cache = decoder_model.init_cache(1)
|
||
self.input_pos = torch.tensor([10]).int().cuda()
|
||
|
||
def _handle_request(self, request: T2SRequest):
|
||
with self.device:
|
||
for i in self.kv_cache:
|
||
i.empty()
|
||
|
||
decoder = self.decoder_model
|
||
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
|
||
self.input_pos.copy_(session.input_pos)
|
||
|
||
t1 = 0.0
|
||
infer_speed = 0.0
|
||
y = session.y
|
||
bsz = y.size(0)
|
||
|
||
for idx in tqdm(range(1500)):
|
||
if idx == 0:
|
||
xy_dec = decoder.h.prefill(
|
||
session.xy_pos, session.attn_mask_nested, self.kv_cache
|
||
)
|
||
xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
|
||
else:
|
||
if (
|
||
request.use_cuda_graph
|
||
and self.graph is None
|
||
and torch.cuda.is_available()
|
||
):
|
||
self.xy_pos_.copy_(session.xy_pos)
|
||
self.graph = decoder.capture(
|
||
self.input_pos,
|
||
self.xy_pos_,
|
||
self.xy_dec_,
|
||
kv_caches=self.kv_cache,
|
||
)
|
||
|
||
if self.graph:
|
||
self.xy_pos_.copy_(session.xy_pos)
|
||
self.graph.replay()
|
||
xy_dec = self.xy_dec_.clone()
|
||
else:
|
||
xy_dec = decoder.h.forward(
|
||
self.input_pos,
|
||
session.xy_pos,
|
||
self.kv_cache,
|
||
)
|
||
|
||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||
self.input_pos.add_(1)
|
||
|
||
if idx == 0:
|
||
logits[:, -1] = float("-inf")
|
||
|
||
samples = session.sampler.sample(
|
||
logits=logits,
|
||
previous_tokens=session.y,
|
||
top_k=request.top_k,
|
||
top_p=request.top_p,
|
||
repetition_penalty=request.repetition_penalty,
|
||
temperature=request.temperature,
|
||
)
|
||
|
||
session.y = torch.cat([session.y, samples], dim=1)
|
||
|
||
argmax_token = torch.argmax(logits, dim=-1)
|
||
sample_token = samples.squeeze(1)
|
||
EOS_mask = (argmax_token == decoder.EOS) | (
|
||
sample_token == decoder.EOS
|
||
)
|
||
|
||
newly_done_mask = EOS_mask & (~session.completed)
|
||
newly_done_indices = newly_done_mask.nonzero()
|
||
|
||
if newly_done_indices.numel() > 0:
|
||
session.y_results[newly_done_indices[0]] = session.y[
|
||
newly_done_indices[0], session.y_len : -1
|
||
].squeeze(0)
|
||
session.completed[newly_done_indices] = True
|
||
|
||
if torch.all(session.completed).item():
|
||
if session.y.size(1) == 0:
|
||
session.y = torch.cat(
|
||
[session.y, torch.zeros_like(samples)], dim=1
|
||
)
|
||
tqdm.write("Bad Zero Prediction")
|
||
else:
|
||
tqdm.write(
|
||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n"
|
||
f"{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
|
||
)
|
||
tqdm.write(
|
||
f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s"
|
||
)
|
||
infer_speed = (idx - 1) / (time.perf_counter() - t1)
|
||
break
|
||
|
||
if (
|
||
request.early_stop_num != -1
|
||
and (session.y.size(1) - session.y_len) > request.early_stop_num
|
||
) or idx == 1499:
|
||
for i in range(bsz):
|
||
if not session.completed[i].item():
|
||
session.y_results[i] = session.y[i, session.y_len :]
|
||
session.completed[i] = True
|
||
break
|
||
|
||
y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
|
||
session.xy_pos = decoder.ar_audio_position.forward(
|
||
self.input_pos - session.x_lens, y_emb
|
||
)
|
||
|
||
if idx == 2:
|
||
t1 = time.perf_counter()
|
||
|
||
if idx % 100 == 0 and self.device.type == "cuda":
|
||
torch.cuda.empty_cache()
|
||
|
||
if self.device.type == "cuda":
|
||
torch.cuda.empty_cache()
|
||
|
||
return session.y_results[: request.valid_length], infer_speed
|
||
|
||
def generate(self, request: T2SRequest) -> T2SResult:
|
||
try:
|
||
result, infer_speed = self._handle_request(request)
|
||
t2s_result = T2SResult(
|
||
result=result, infer_speed=infer_speed, status="Success"
|
||
)
|
||
except Exception as e:
|
||
t2s_result = T2SResult(
|
||
status="Error", exception=e, traceback=traceback.format_exc()
|
||
)
|
||
return t2s_result
|
||
|
||
@staticmethod
|
||
def load_decoder(weights_path, max_batch_size=1) -> T2SDecoder:
|
||
print(
|
||
f"Loading Text2Semantic Weights from {weights_path} with CUDA Graph (SDPA) Implement"
|
||
)
|
||
dict_s1 = torch.load(
|
||
weights_path, map_location="cpu", weights_only=False#, mmap=True
|
||
)
|
||
config = dict_s1["config"]
|
||
decoder = T2SDecoder(config, max_batch_size=max_batch_size)
|
||
state_dict = dict_s1["weight"]
|
||
decoder.load_state_dict(state_dict)
|
||
return decoder.eval()
|