# Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2026-07-02 20:10:27 +08:00).
# File listing: 219 lines, 9.9 KiB, Python.
import contextlib
|
|
import os
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from importlib import import_module
|
|
from typing import Literal
|
|
|
|
import torch
|
|
from rich.progress import BarColumn, Progress, TextColumn
|
|
|
|
from ..logger import SpeedColumnToken, console, logger, timer
|
|
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
|
|
from .t2s_model_abc import (
|
|
CUDAGraphCacheABC,
|
|
T2SDecoderABC,
|
|
TorchProfiler,
|
|
)
|
|
|
|
|
|
class T2SEngine(T2SEngineProtocol):
    """Autoregressive Text2Semantic (T2S) decoding engine.

    Holds a ``T2SDecoderABC`` model moved to the target device/dtype and drives
    the prefill + step-by-step decode loop for a ``T2SRequest``, optionally
    replaying a CUDA graph obtained from the backend's ``CUDAGraphCache``.
    """

    def __init__(
        self,
        decoder_model: T2SDecoderABC,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Move the decoder to ``device``/``dtype`` and build the graph cache.

        Args:
            decoder_model: Decoder to run; it is moved with ``.to(device, dtype)``.
            device: Target device. An ``"mps"`` device is replaced by the CPU.
            dtype: One of ``torch.float16``, ``torch.bfloat16``, ``torch.float32``.

        Raises:
            AssertionError: If the device type or dtype is not in the accepted set.
        """
        assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}, f"Unsupported device type: {device.type}"
        assert dtype in {torch.float16, torch.bfloat16, torch.float32}, f"Unsupported dtype: {dtype}"

        # "mps" is accepted as an argument but the engine itself runs on CPU in that case.
        self.device = device if device.type != "mps" else torch.device("cpu")
        self.dtype = dtype

        self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
        # self.decoder_model.compile()

        self.graphcache: CUDAGraphCacheABC = self.init_cache()

    def _handle_request(self, request: T2SRequest):
        """Run prefill and autoregressive decoding for one (batched) request.

        Args:
            request: Decoding request. This method reads ``debug``,
                ``use_cuda_graph``, ``top_k``, ``top_p``, ``repetition_penalty``,
                ``temperature``, ``early_stop_num`` and ``valid_length`` from it.

        Returns:
            A 4-tuple ``(results, infer_speed, infer_time, total_tokens)`` where
            ``results`` is ``session.y_results[: request.valid_length]``,
            ``infer_time`` is the elapsed seconds since just before the prefill
            forward pass, ``infer_speed`` is tokens per second over that span and
            ``total_tokens`` is ``(number of steps run) * batch size``.
        """
        with self.device:
            decoder = self.decoder_model
            session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
            batch_idx = torch.arange(session.bsz)
            debug = request.debug

            t1 = 0.0
            infer_speed = 0.0
            infer_time = 0.0
            idx = 0

            torch_profiler = TorchProfiler(debug)
            with (
                torch_profiler.profiler(),
                Progress(
                    TextColumn("[cyan]{task.description}"),
                    BarColumn(),
                    TextColumn("{task.completed}/{task.total} tokens"),
                    SpeedColumnToken(show_speed=True),
                    console=console,
                    transient=True,
                ) as progress,
            ):
                torch_profiler.start()
                # Every loop iteration emits one token for each sequence in the batch,
                # so the loop is bounded by per-sequence *steps* while the progress bar
                # counts *tokens* (steps * batch size). Iterating over the token count
                # would overrun the per-sequence budget by a factor of ``bsz``.
                max_steps = min(int(1500 - session.input_pos.max()), 1000)
                max_token = max_steps * session.bsz
                task = progress.add_task("T2S Decoding", total=max_token)

                for idx in range(max_steps):
                    progress.update(task, advance=session.bsz)
                    if idx == 0:
                        with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
                            session.kv_cache = decoder.init_cache(session.bsz)
                            t1 = time.perf_counter()
                            xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
                            # Keep only the hidden state at each sequence's last input position.
                            xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
                    else:
                        # Lazily attach a CUDA graph on the first decode step.
                        if (
                            request.use_cuda_graph
                            and session.graph is None
                            and self.graphcache.is_applicable
                            and torch.cuda.is_available()
                        ):
                            self.graphcache.assign_graph(session)

                        with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
                            if session.graph:
                                assert session.stream
                                session.stream.wait_stream(torch.cuda.default_stream())
                                with torch.cuda.stream(session.stream):
                                    # Feed the graph's static input buffer, replay, and copy
                                    # the static output out before the next replay.
                                    session.xy_pos_.copy_(session.xy_pos)
                                    session.graph.replay()
                                    xy_dec = session.xy_dec_.clone()
                            else:
                                args, kwds = decoder.pre_forward(session)
                                xy_dec = decoder.h(
                                    session.input_pos,
                                    session.xy_pos,
                                    session.kv_cache,
                                    *args,
                                    **kwds,
                                )

                    with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
                        decoder.post_forward(idx, session)
                        logits = decoder.ar_predict_layer(xy_dec.squeeze(1))

                        if idx == 0:
                            # Mask the last logit on the first step
                            # (presumably the EOS id, so a result is never empty - TODO confirm).
                            logits[:, -1] = float("-inf")

                        with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
                            samples = session.sample(
                                logits=logits,
                                previous_tokens=session.y[:, : session.y_len + idx],
                                top_k=request.top_k,
                                top_p=request.top_p,
                                repetition_penalty=request.repetition_penalty,
                                temperature=request.temperature,
                            )
                            session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
                            session.input_pos.add_(1)

                        with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
                            # A sequence ends when either the greedy or the sampled token is EOS.
                            argmax_token = torch.argmax(logits, dim=-1)
                            sample_token = samples.squeeze(1)
                            EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)

                            newly_done_mask = EOS_mask & (~session.completed)
                            newly_done_indices = newly_done_mask.nonzero()

                            if newly_done_indices.numel() > 0:
                                for i in newly_done_indices:
                                    # The slice stops before the token written at this step.
                                    session.y_results[i] = session.y[i, session.y_len : session.y_len + idx].squeeze(0)
                                session.completed[newly_done_indices] = True

                            if torch.all(session.completed).item():
                                prefill_lens = str(session.prefill_len.tolist()).strip("[]")
                                result_lens = str([i.size(-1) for i in session.y_results]).strip("[]")
                                logger.info(f"T2S Decoding EOS {prefill_lens} -> {result_lens}")
                                # Measure once so the logged and the returned speed agree.
                                infer_time = time.perf_counter() - t1
                                infer_speed = (idx + 1) * session.bsz / infer_time
                                logger.info(f"Infer Speed: {infer_speed:.2f} token/s")
                                break

                            if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_steps - 1:
                                # Step budget exhausted: force-finish every unfinished sequence.
                                for i in range(session.bsz):
                                    if not session.completed[i].item():
                                        session.y_results[i] = session.y[
                                            [i], session.y_len : session.y_len + idx
                                        ].squeeze(0)
                                        session.completed[i] = True
                                logger.error("Bad Full Prediction")
                                infer_time = time.perf_counter() - t1
                                infer_speed = (idx + 1) * session.bsz / infer_time
                                break

                        with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
                            # Embed the sampled tokens to form the next step's input.
                            y_emb = decoder.ar_audio_embedding(samples)
                            session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)

                        if idx == 10:
                            # Profiling is ended explicitly once step 10 has been reached.
                            torch_profiler.end()

            if request.use_cuda_graph and self.graphcache.is_applicable:
                self.graphcache.release_graph(session)

            return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz

    def generate(self, request: T2SRequest):
        """Decode ``request`` and wrap the outcome in a ``T2SResult``.

        Exceptions raised while decoding are not propagated: they are captured
        into a result with ``status="Error"`` together with the formatted
        traceback. On success the result has ``status="Success"`` and carries
        the decoded sequences, ``(infer_speed, infer_time)`` and the token count.
        """
        try:
            result, infer_speed, infer_time, total_tokens = self._handle_request(request)
            t2s_result = T2SResult(
                result=result,
                infer_speed=(infer_speed, infer_time),
                total_tokens=total_tokens,
                status="Success",
            )
        except Exception as e:
            # Boundary of the engine: report the failure to the caller as data.
            t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
        if self.decoder_model.compiled:
            self.decoder_model.save_compile_cache()
            # NOTE(review): this sets an attribute on the engine, while the flag tested
            # above lives on ``self.decoder_model`` - confirm which one was intended.
            self.compiled = None
        return t2s_result

    @staticmethod
    def load_decoder(
        weights_path: os.PathLike,
        max_batch_size: int = 1,
        backend: str = "Flash-Attn-Varlen-CUDAGraph",
        quantize_mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None,
    ) -> T2SDecoderABC:
        """Build a decoder for ``backend`` and load its weights from ``weights_path``.

        The backend name is mapped to a module under ``.backends`` by lower-casing
        it, replacing ``-`` with ``_`` and ``cudagraph`` with ``cuda_graph``
        (``"Flash-Attn-Varlen-CUDAGraph"`` -> ``.backends.flash_attn_varlen_cuda_graph``);
        that module must expose a ``T2SDecoder`` class.

        Args:
            weights_path: Checkpoint containing ``"config"`` and ``"weight"`` entries.
            max_batch_size: Forwarded to the decoder constructor.
            backend: Backend name used to locate the decoder module.
            quantize_mode: If given, ``decoder.quantize(quantize_mode)`` is applied
                after the weights are loaded.

        Returns:
            The loaded decoder, switched to eval mode.
        """
        logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
        module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
        decoder_cls_name = "T2SDecoder"
        decoder_mod = import_module(module_path, package=__package__)
        decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
        # SECURITY: ``weights_only=False`` unpickles arbitrary Python objects and can
        # execute code embedded in the file; only load checkpoints from trusted sources.
        dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
        config = dict_s1["config"]
        decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
        state_dict = dict_s1["weight"]
        decoder.load_state_dict(state_dict)

        if quantize_mode is not None:
            decoder.quantize(quantize_mode)
            logger.info(f"Quantized by {quantize_mode} Quantization")

        return decoder.eval()

    def init_cache(self):
        """Instantiate the ``CUDAGraphCache`` class of the decoder's own backend module.

        The class is looked up by name in the already-imported module that
        defines ``self.decoder_model``'s class and is constructed with the decoder.
        """
        assert self.decoder_model

        module_name = self.decoder_model.__class__.__module__
        module = sys.modules.get(module_name)
        assert module

        target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")

        return target_class(self.decoder_model)
|