import contextlib
import gc
import os
import sys
import time
import traceback
from importlib import import_module

import torch
from rich.progress import BarColumn, Progress, TextColumn

from ..logger import SpeedColumnToken, console, logger
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
from .t2s_model_abc import (
    CUDAGraphCacheABC,
    T2SDecoderABC,
    TorchProfiler,
)


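# T2SEngine drives autoregressive text-to-semantic (T2S) decoding on top of a
# backend-specific decoder, with optional CUDA-graph replay for the per-token step.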
class T2SEngine(T2SEngineProtocol):
    def __init__(
        self,
        decoder_model: T2SDecoderABC,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> None:
        assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
        assert dtype in {torch.float16, torch.bfloat16, torch.float32}

        self.device = device
        self.dtype = dtype

        self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)

        self.graphcache: CUDAGraphCacheABC = self.init_cache()

    def _handle_request(self, request: T2SRequest):
        with self.device:
            decoder = self.decoder_model
            session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
            batch_idx = torch.arange(session.bsz)

            t1 = 0.0
            infer_speed = 0.0
            infer_time = 0.0

            torch_profiler = TorchProfiler(request.debug)
            with (
                torch_profiler.profiler(),
                Progress(
                    TextColumn("[cyan]{task.description}"),
                    BarColumn(),
                    TextColumn("{task.completed}/{task.total} tokens"),
                    SpeedColumnToken(show_speed=True),
                    console=console,
                    transient=True,
                ) as progress,
            ):
                max_token = int(min(1800 - session.input_pos.max(), 1500))
                task = progress.add_task("T2S Decoding", total=max_token)

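                # Decode loop: each iteration emits one semantic token per batch item;
                # step 0 runs the prompt prefill, later steps are single-token AR updates.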
                for idx in range(max_token):
                    progress.update(task, advance=1)
                    if idx == 0:
                        session.kv_cache = decoder.init_cache(session.bsz)
                        xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
                        xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
                    else:
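                        # Lazily bind a captured CUDA graph to this session the first time
                        # graph replay is requested and supported.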
                        if (
                            request.use_cuda_graph
                            and session.graph is None
                            and self.graphcache.is_applicable
                            and torch.cuda.is_available()
                        ):
                            self.graphcache.assign_graph(session)

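                        # One AR forward step: replay the captured graph on its side stream
                        # when available, otherwise run the decoder eagerly.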
                        with torch_profiler.record("AR"):
                            if session.graph:
                                assert session.stream
                                session.stream.wait_stream(torch.cuda.default_stream())
                                with torch.cuda.stream(session.stream):
                                    session.xy_pos_.copy_(session.xy_pos)
                                    session.graph.replay()
                                    xy_dec = session.xy_dec_.clone()
                            else:
                                args, kwds = decoder.pre_forward(session)
                                xy_dec = decoder.h(
                                    session.input_pos,
                                    session.xy_pos,
                                    session.kv_cache,
                                    *args,
                                    **kwds,
                                )

                    with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
                        decoder.post_forward(idx, session)
                        logits = decoder.ar_predict_layer(xy_dec[:, -1])

                        if idx == 0:
                            logits[:, -1] = float("-inf")

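                        # Sample the next semantic token with repetition penalty / top-k / top-p,
                        # then append it to the running sequence and advance the position.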
                        with torch_profiler.record("Sampling"):
                            samples = session.sample(
                                logits=logits,
                                previous_tokens=session.y[:, : session.y_len + idx],
                                top_k=request.top_k,
                                top_p=request.top_p,
                                repetition_penalty=request.repetition_penalty,
                                temperature=request.temperature,
                            )
                            session.y[batch_idx, session.y_len + idx] = samples
                            session.input_pos.add_(1)

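                        # EOS handling: a sequence is finished when either the argmax token or
                        # the sampled token hits EOS; finished sequences are copied into y_results.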
                        with torch_profiler.record("EOS"):
                            argmax_token = torch.argmax(logits, dim=-1)
                            sample_token = samples.squeeze(1)
                            EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)

                            newly_done_mask = EOS_mask & (~session.completed)
                            newly_done_indices = newly_done_mask.nonzero()

                            if newly_done_indices.numel() > 0:
                                for i in newly_done_indices:
                                    session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
                                session.completed[newly_done_indices] = True

                            if torch.all(session.completed).item():
                                if session.y[:, session.y_len :].sum() == 0:
                                    session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
                                    logger.error("Bad Zero Prediction")
                                else:
                                    logger.info(
                                        f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
                                    )
                                    logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
                                    infer_time = time.perf_counter() - t1
                                    infer_speed = (idx - 1) / infer_time
                                break

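                        # Early stop / length overflow: truncate any still-unfinished sequences and exit the loop.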
                        if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
                            for i in range(session.bsz):
                                if not session.completed[i].item():
                                    session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
                                    session.completed[i] = True
                            logger.error("Bad Full Prediction")
                            break

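                        # Prepare the next-step input: embed the sampled token and add its position encoding.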
                        with torch_profiler.record("NextPos"):
                            y_emb = decoder.ar_audio_embedding(samples)
                            session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)

                        if idx == 1:
                            torch_profiler.start()
                            t1 = time.perf_counter()

                        if idx == 51:
                            torch_profiler.end()

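                        # Periodically release cached allocator blocks on accelerator backends.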
                        if idx % 100 == 0:
                            match session.device.type:
                                case "cuda":
                                    torch.cuda.empty_cache()
                                case "mps":
                                    torch.mps.empty_cache()
                                case "xpu":
                                    torch.xpu.empty_cache()
                                case "mtia":
                                    torch.mtia.empty_cache()

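            # Decoding finished: synchronize the side stream (if any) and free cached memory.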
            match session.device.type:
                case "cuda":
                    if session.stream is not None:
                        torch.cuda.current_stream().wait_stream(session.stream)
                    torch.cuda.empty_cache()
                case "mps":
                    torch.mps.empty_cache()
                case "xpu":
                    torch.xpu.empty_cache()
                case "mtia":
                    torch.mtia.empty_cache()
                case "cpu":
                    gc.collect()

            torch_profiler.end()
            if request.use_cuda_graph and torch.cuda.is_available():
                self.graphcache.release_graph(session)

            return session.y_results[: request.valid_length], infer_speed, infer_time

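    # Public entry point: wraps _handle_request and reports failures as a T2SResult instead of raising.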
    def generate(self, request: T2SRequest):
        try:
            result, infer_speed, infer_time = self._handle_request(request)
            t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
        except Exception as e:
            t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
        return t2s_result

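    # Imports the T2SDecoder class from the selected backend module (for the default
    # backend the path resolves to ".backends.flash_attn_varlen_cuda_graph") and
    # restores its weights from the checkpoint.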
    @staticmethod
    def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"):
        logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
        module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
        decoder_cls_name = "T2SDecoder"
        decoder_mod = import_module(module_path, package=__package__)
        decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
        dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
        config = dict_s1["config"]
        decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
        state_dict = dict_s1["weight"]
        decoder.load_state_dict(state_dict)

        return decoder.eval()

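    # Resolves the backend's CUDAGraphCache class from the module that defines the loaded decoder.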
    def init_cache(self):
        assert self.decoder_model

        module_name = self.decoder_model.__class__.__module__
        module = sys.modules.get(module_name)
        assert module

        target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")

        return target_class(self.decoder_model)
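

# Illustrative usage sketch (not part of the module): the checkpoint path and the
# T2SRequest construction below are hypothetical, and it assumes T2SResult exposes
# the fields it is constructed with (status/result) as attributes.
#
#     decoder = T2SEngine.load_decoder("path/to/s1_weights.ckpt", max_batch_size=1)
#     engine = T2SEngine(decoder, device=torch.device("cuda"), dtype=torch.bfloat16)
#     result = engine.generate(request)  # request: a T2SRequest built by the caller
#     if result.status == "Success":
#         semantic_tokens = result.result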