mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-01 02:59:25 +08:00
.
This commit is contained in:
parent
0ec73aac3c
commit
5da25d4f18
@ -1,4 +1,5 @@
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
@ -25,8 +26,12 @@ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
cpu_count = os.cpu_count() or 1
|
||||
torch.set_num_threads(cpu_count)
|
||||
torch.set_num_interop_threads(cpu_count)
|
||||
|
||||
backends = ["torch_varlen"]
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and torch.version.cuda is not None:
|
||||
backends.append("torch_static_cuda_graph")
|
||||
# if importlib.util.find_spec("sageattention") is not None:
|
||||
# for i in range(torch.cuda.device_count()):
|
||||
@ -44,7 +49,7 @@ if torch.cuda.is_available():
|
||||
# backends.append("mps_flash_attn_varlen")
|
||||
|
||||
BLACKWELL = False
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and torch.version.cuda is not None:
|
||||
for i in range(torch.cuda.device_count()):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
sm_version = major + minor / 10.0
|
||||
|
||||
@ -12,6 +12,7 @@ from ..structs import T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
CUDAGraphStateABC,
|
||||
FeedForward,
|
||||
KVCacheNHD,
|
||||
KVCacheProtocol,
|
||||
@ -114,6 +115,8 @@ class T2SDecoder(T2SDecoderABC):
|
||||
|
||||
self.kv_class = KVCacheNHD
|
||||
|
||||
self.graph_cache_class = CUDAGraphCache
|
||||
|
||||
def compile(self, *args, **kwds):
|
||||
pass
|
||||
|
||||
@ -124,40 +127,39 @@ class T2SDecoder(T2SDecoderABC):
|
||||
return super().pre_forward(session)
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
class CUDAGraphState(CUDAGraphStateABC):
|
||||
applicable: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
bsz: int,
|
||||
decoder: T2SDecoderABC,
|
||||
) -> None:
|
||||
self.is_applicable = True
|
||||
super().__init__(decoder)
|
||||
super().__init__(bsz, decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
assert session.graph
|
||||
session.graph.reset()
|
||||
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
||||
def capture(self):
|
||||
graph = self.decoder.capture(
|
||||
self.input_pos,
|
||||
self.xy_pos,
|
||||
self.xy_dec,
|
||||
self.kv_cache,
|
||||
)
|
||||
self.graph = graph
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
is_applicable = True
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
cache_size: int = 5,
|
||||
) -> None:
|
||||
super().__init__(decoder, cache_size)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
||||
def create_graph_cache(self, bsz: int):
|
||||
for _ in range(self.cache_size):
|
||||
state = CUDAGraphState(bsz, self.decoder)
|
||||
state.capture()
|
||||
self.graph_cache[bsz].put(state)
|
||||
|
||||
@ -6,6 +6,7 @@ from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
CUDAGraphStateABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
@ -91,6 +92,8 @@ class T2SDecoder(T2SDecoderABC):
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
self.graph_cache_class = CUDAGraphCache
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
@ -111,57 +114,54 @@ class T2SDecoder(T2SDecoderABC):
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphState(CUDAGraphStateABC):
|
||||
applicable: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bsz: int,
|
||||
decoder: T2SDecoderABC,
|
||||
) -> None:
|
||||
self.attn_mask: Tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
2,
|
||||
(bsz, decoder.n_head, 1, decoder.max_seq_length),
|
||||
)
|
||||
.bool()
|
||||
.to(decoder.device)
|
||||
)
|
||||
|
||||
super().__init__(bsz, decoder)
|
||||
|
||||
def capture(self):
|
||||
graph = self.decoder.capture(
|
||||
self.input_pos,
|
||||
self.xy_pos,
|
||||
self.xy_dec,
|
||||
self.kv_cache,
|
||||
attn_mask=self.attn_mask,
|
||||
)
|
||||
self.graph = graph
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
def assign_graph(self, session: T2SSession):
|
||||
session.attn_mask = self.attn_mask
|
||||
return super().assign_graph(session)
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
is_applicable = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
cache_size: int = 5,
|
||||
) -> None:
|
||||
self.is_applicable = True
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
|
||||
.bool()
|
||||
.to(self.device, self.dtype)
|
||||
)
|
||||
super().__init__(decoder, cache_size)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
assert session.graph
|
||||
session.graph.reset()
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.attn_mask,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
||||
def create_graph_cache(self, bsz: int):
|
||||
for _ in range(self.cache_size):
|
||||
state = CUDAGraphState(bsz, self.decoder)
|
||||
state.capture()
|
||||
self.graph_cache[bsz].put(state)
|
||||
|
||||
@ -99,6 +99,8 @@ class T2SDecoder(T2SDecoderABC):
|
||||
|
||||
self.kv_class = KVCacheHNDVarlen
|
||||
|
||||
self.graph_cache_class = CUDAGraphCache
|
||||
|
||||
def capture(
|
||||
self,
|
||||
*args,
|
||||
@ -127,18 +129,13 @@ class T2SDecoder(T2SDecoderABC):
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
is_applicable = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
def create_graph_cache(self, bsz: int) -> NoReturn:
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
@ -34,7 +34,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
||||
# self.decoder_model.compile()
|
||||
|
||||
self.graphcache: CUDAGraphCacheABC = self.init_cache()
|
||||
self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model)
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
with self.device:
|
||||
@ -47,6 +47,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
infer_speed = 0.0
|
||||
infer_time = 0.0
|
||||
idx = 0
|
||||
graph_state = None
|
||||
|
||||
torch_profiler = TorchProfiler(debug)
|
||||
with (
|
||||
@ -64,105 +65,132 @@ class T2SEngine(T2SEngineProtocol):
|
||||
max_token = min(int(1500 - session.input_pos.max()), 1000) * session.bsz
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=session.bsz)
|
||||
if idx == 0:
|
||||
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
t1 = time.perf_counter()
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
|
||||
else:
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
and session.graph is None
|
||||
and self.graphcache.is_applicable
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
self.graphcache.assign_graph(session)
|
||||
|
||||
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
|
||||
if session.graph:
|
||||
assert session.stream
|
||||
session.stream.wait_stream(torch.cuda.default_stream())
|
||||
with torch.cuda.stream(session.stream):
|
||||
session.xy_pos_.copy_(session.xy_pos)
|
||||
session.graph.replay()
|
||||
xy_dec = session.xy_dec_.clone()
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
|
||||
|
||||
try:
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=session.bsz)
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
t1 = time.perf_counter()
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
|
||||
else:
|
||||
if (
|
||||
idx == 1
|
||||
and request.use_cuda_graph
|
||||
and self.graphcache.is_applicable
|
||||
and torch.cuda.is_available()
|
||||
and torch.version.cuda is not None
|
||||
and os.environ.get("CUDAGraph", "1") != "0"
|
||||
):
|
||||
graph_state = self.graphcache[session.bsz].assign_graph(session)
|
||||
|
||||
with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
|
||||
session.input_pos.add_(1)
|
||||
|
||||
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
|
||||
argmax_token = torch.argmax(logits, dim=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = newly_done_mask.nonzero()
|
||||
|
||||
if newly_done_indices.numel() > 0:
|
||||
for i in newly_done_indices:
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx].squeeze(0)
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if torch.all(session.completed).item():
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(
|
||||
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
|
||||
)
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for i in range(session.bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[[i], session.y_len : session.y_len + idx].squeeze(
|
||||
0
|
||||
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
|
||||
if session.graph:
|
||||
assert session.stream
|
||||
session.stream.wait_stream(torch.cuda.default_stream())
|
||||
with torch.cuda.stream(session.stream):
|
||||
session.xy_pos_.copy_(session.xy_pos)
|
||||
session.graph.replay()
|
||||
xy_dec = session.xy_dec_.clone()
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
session.completed[i] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
|
||||
with (
|
||||
torch.cuda.stream(session.stream)
|
||||
if session.stream is not None
|
||||
else contextlib.nullcontext()
|
||||
):
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
|
||||
with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
|
||||
session.input_pos.add_(1)
|
||||
|
||||
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
|
||||
argmax_token = torch.argmax(logits, dim=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = newly_done_mask.nonzero()
|
||||
|
||||
if newly_done_indices.numel() > 0:
|
||||
for i in newly_done_indices:
|
||||
session.y_results[i] = session.y[
|
||||
i, session.y_len : session.y_len + idx
|
||||
].squeeze(0)
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if torch.all(session.completed).item():
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(
|
||||
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
|
||||
)
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
break
|
||||
|
||||
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for i in range(session.bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[
|
||||
[i], session.y_len : session.y_len + idx
|
||||
].squeeze(0)
|
||||
session.completed[i] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
|
||||
if idx == 10:
|
||||
torch_profiler.end()
|
||||
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
|
||||
if request.use_cuda_graph and self.graphcache.is_applicable:
|
||||
self.graphcache.release_graph(session)
|
||||
if idx == 10:
|
||||
torch_profiler.end()
|
||||
finally:
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
and self.graphcache.is_applicable
|
||||
and torch.cuda.is_available()
|
||||
and torch.version.cuda is not None
|
||||
and os.environ.get("CUDAGraph", "1") != "0"
|
||||
):
|
||||
self.graphcache.release_graph(graph_state)
|
||||
|
||||
match decoder.device.type:
|
||||
case "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
case "cpu":
|
||||
gc.collect(1)
|
||||
|
||||
return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz
|
||||
|
||||
@ -205,14 +233,3 @@ class T2SEngine(T2SEngineProtocol):
|
||||
logger.info(f"Quantized by {quantize_mode} Quantization")
|
||||
|
||||
return decoder.eval()
|
||||
|
||||
def init_cache(self):
|
||||
assert self.decoder_model
|
||||
|
||||
module_name = self.decoder_model.__class__.__module__
|
||||
module = sys.modules.get(module_name)
|
||||
assert module
|
||||
|
||||
target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
|
||||
|
||||
return target_class(self.decoder_model)
|
||||
|
||||
@ -8,11 +8,11 @@ import math
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Literal, MutableSequence
|
||||
|
||||
import torch
|
||||
@ -479,6 +479,8 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
|
||||
self.graph_cache_class: type[CUDAGraphCacheABC]
|
||||
|
||||
self.bits: int
|
||||
self.group_size: int
|
||||
|
||||
@ -639,56 +641,78 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
raise ValueError(f"Unsupported Quantization Mode for PyTorch: {mode}")
|
||||
|
||||
|
||||
class CUDAGraphCacheABC(ABC):
|
||||
class CUDAGraphStateABC(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
bsz: int,
|
||||
decoder: T2SDecoderABC,
|
||||
) -> None:
|
||||
self.is_applicable: bool
|
||||
self.bsz = bsz
|
||||
self.embedding_dim = decoder.embedding_dim
|
||||
self.dtype = decoder.bert_proj.bias.dtype
|
||||
self.device = decoder.device
|
||||
|
||||
if torch.cuda.is_available() and self.is_applicable:
|
||||
self.device: torch.device = decoder.device
|
||||
self.dtype = decoder.bert_proj.bias.dtype
|
||||
self.decoder: T2SDecoderABC = decoder
|
||||
self.graph: torch.cuda.CUDAGraph | None = None
|
||||
self.stream: torch.cuda.Stream | None = None
|
||||
|
||||
self.assigned: bool = False
|
||||
self.xy_pos = torch.rand(size=(self.bsz, 1, self.embedding_dim), device=self.device).to(self.dtype)
|
||||
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(bsz)
|
||||
self.xy_dec = self.xy_pos.clone()
|
||||
self.input_pos = torch.tensor([10] * self.bsz, device=self.device).to(torch.int32)
|
||||
|
||||
self.decoder: T2SDecoderABC = decoder
|
||||
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
|
||||
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
|
||||
self.dtype
|
||||
)
|
||||
self.xy_dec = self.xy_pos.clone()
|
||||
self.capture()
|
||||
|
||||
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
|
||||
self.graph: torch.cuda.CUDAGraph | None = None
|
||||
self.stream: torch.cuda.Stream | None
|
||||
@abstractmethod
|
||||
def capture(self): ...
|
||||
|
||||
self.id: int = random.randint(1, 2**32 - 1)
|
||||
def assign_graph(self, session: T2SSession) -> CUDAGraphStateABC:
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
def assign_graph(self, session: T2SSession):
|
||||
if self.graph is None:
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
self.graph = graph
|
||||
self.stream = torch.cuda.Stream()
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
if self.assigned is False:
|
||||
self.get_cache_graph(session)
|
||||
session.id = self.id
|
||||
self.assigned = True
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class CUDAGraphCacheABC(ABC):
|
||||
is_applicable: bool
|
||||
|
||||
def __init__(self, decoder: T2SDecoderABC, cache_size: int = 5) -> None:
|
||||
self.decoder = decoder
|
||||
self.max_batch_size = decoder.max_batch_size
|
||||
self.cache_size = cache_size
|
||||
|
||||
self.graph_cache: dict[int, Queue[CUDAGraphStateABC]] = {}
|
||||
|
||||
if torch.cuda.is_available() and torch.version.cuda is not None and os.environ.get("CUDAGraph", "1") != "0":
|
||||
self.create_graph_cache(1)
|
||||
|
||||
def __getitem__(self, bsz: int) -> CUDAGraphStateABC:
|
||||
if self.is_applicable:
|
||||
assert bsz <= self.max_batch_size
|
||||
if self.graph_cache.get(bsz) is None:
|
||||
self.create_graph_cache(bsz)
|
||||
return self.graph_cache[bsz].get()
|
||||
else:
|
||||
self.capture_new_graph(session)
|
||||
raise RuntimeError("CUDAGraph Is Not Applicable")
|
||||
|
||||
@abstractmethod
|
||||
def release_graph(self, session: T2SSession): ...
|
||||
def create_graph_cache(self, bsz: int): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
pass
|
||||
def release_graph(self, graph_state: CUDAGraphStateABC | None):
|
||||
if graph_state is None:
|
||||
return
|
||||
bsz = graph_state.bsz
|
||||
assert bsz <= self.max_batch_size
|
||||
assert self.graph_cache.get(bsz) is not None
|
||||
self.graph_cache[bsz].put(graph_state)
|
||||
|
||||
|
||||
class TorchProfiler:
|
||||
|
||||
@ -955,7 +955,7 @@ class TTS:
|
||||
"""
|
||||
########## variables initialization ###########
|
||||
torch.set_grad_enabled(False)
|
||||
ttfb_time = time.perf_counter()
|
||||
ttft_time = time.perf_counter()
|
||||
self.stop_flag: bool = False
|
||||
text: str = inputs.get("text", "")
|
||||
text_lang: str = inputs.get("text_lang", "")
|
||||
@ -1141,7 +1141,9 @@ class TTS:
|
||||
temperature=temperature,
|
||||
repetition_penalty=repetition_penalty,
|
||||
debug=os.environ.get("DEBUG", "0") == "1",
|
||||
use_cuda_graph=torch.cuda.is_available(),
|
||||
use_cuda_graph=torch.cuda.is_available()
|
||||
and torch.version.cuda is not None
|
||||
and os.environ.get("CUDAGraph", "1") != "0",
|
||||
)
|
||||
t2s_result = self.t2s_model.generate(t2s_request)
|
||||
|
||||
@ -1151,7 +1153,7 @@ class TTS:
|
||||
|
||||
pred_semantic_list = t2s_result.result
|
||||
assert pred_semantic_list
|
||||
pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list]
|
||||
pred_semantic_list = [semantic.squeeze(0).to(self.configs.device) for semantic in pred_semantic_list]
|
||||
infer_len.append(t2s_result.total_tokens)
|
||||
infer_time.append(t2s_result.infer_speed[-1])
|
||||
|
||||
@ -1243,7 +1245,7 @@ class TTS:
|
||||
)
|
||||
batch_audio_fragment.append(audio_fragment)
|
||||
if idx == 0:
|
||||
ttfb_time = time.perf_counter() - ttfb_time
|
||||
ttft_time = time.perf_counter() - ttft_time
|
||||
t5 = time.perf_counter()
|
||||
t_45 += t5 - t4
|
||||
if return_fragment:
|
||||
@ -1307,10 +1309,10 @@ class TTS:
|
||||
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
|
||||
console.print(f">> RTF: {rtf_value:.2f}")
|
||||
|
||||
if ttfb_time > 2:
|
||||
console.print(f">> TTFB: {ttfb_time:.3f} s")
|
||||
if ttft_time > 2:
|
||||
console.print(f">> TTFT: {ttft_time:.3f} s")
|
||||
else:
|
||||
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
|
||||
console.print(f">> TTFT: {ttft_time * 1000:.3f} ms")
|
||||
|
||||
self.empty_cache()
|
||||
|
||||
|
||||
@ -708,7 +708,7 @@ async def get_tts_wav(
|
||||
torch.set_grad_enabled(False)
|
||||
progress(0, desc="Inferencing...")
|
||||
debug = os.getenv("DEBUG") == "1"
|
||||
ttfb_time = ttime()
|
||||
ttft_time = ttime()
|
||||
|
||||
if ref_wav_path:
|
||||
pass
|
||||
@ -829,7 +829,9 @@ async def get_tts_wav(
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=1500,
|
||||
use_cuda_graph=torch.cuda.is_available(), # Try to use CUDA Graph for all backend, fallback to normal if not applicapble
|
||||
use_cuda_graph=torch.cuda.is_available()
|
||||
and torch.version.cuda is not None
|
||||
and os.environ.get("CUDAGraph", "1") != "0",
|
||||
debug=debug,
|
||||
)
|
||||
assert t2s_engine
|
||||
@ -938,7 +940,7 @@ async def get_tts_wav(
|
||||
wav_gen = vocoder_model(cfm_res) # type: ignore
|
||||
audio = wav_gen[0][0]
|
||||
if i_text == 0:
|
||||
ttfb_time = ttime() - ttfb_time
|
||||
ttft_time = ttime() - ttft_time
|
||||
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
|
||||
if max_audio > 1:
|
||||
audio = audio / max_audio
|
||||
@ -980,12 +982,12 @@ async def get_tts_wav(
|
||||
gr.Info(f"{infer_speed_avg:.2f} Token/s", title="Infer Speed")
|
||||
gr.Info(f"{rtf_value:.2f}", title="RTF")
|
||||
|
||||
if ttfb_time > 2:
|
||||
console.print(f">> TTFB: {ttfb_time:.3f} s")
|
||||
gr.Info(f"{ttfb_time:.3f} s", title="TTFB")
|
||||
if ttft_time > 2:
|
||||
console.print(f">> TTFT: {ttft_time:.3f} s")
|
||||
gr.Info(f"{ttft_time:.3f} s", title="TTFT")
|
||||
else:
|
||||
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
|
||||
gr.Info(f"{ttfb_time * 1000:.3f} ms", title="TTFB")
|
||||
console.print(f">> TTFT: {ttft_time * 1000:.3f} ms")
|
||||
gr.Info(f"{ttft_time * 1000:.3f} ms", title="TTFT")
|
||||
|
||||
progress(1, desc="Done")
|
||||
yield opt_sr, (audio_opt_n * 32767).astype(np.int16)
|
||||
|
||||
@ -958,7 +958,7 @@ class SynthesizerTrn(nn.Module):
|
||||
ge = self.prelu(ge)
|
||||
return ge
|
||||
|
||||
if type(refer) == list:
|
||||
if isinstance(refer, list):
|
||||
ges = []
|
||||
for idx, _refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||
|
||||
@ -51,7 +51,7 @@ Unseen speakers few-shot fine-tuning demo:
|
||||
|
||||
## Infer Speed
|
||||
|
||||
| Device | RTF | TTFB | Batch Size | Backend |
|
||||
| Device | RTF | TTFT | Batch Size | Backend |
|
||||
| :---------: | :---: | :----: | :--------: | :-------------------------: |
|
||||
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
@ -138,13 +138,13 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
conda activate GPTSoVits
|
||||
conda install ffmpeg=7 -c conda-forge
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
##### Ubuntu/Debian Users
|
||||
|
||||
```bash
|
||||
sudo apt install ffmpeg=7
|
||||
sudo apt install ffmpeg
|
||||
sudo apt install libsox-dev
|
||||
```
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@
|
||||
|
||||
## 推理速度
|
||||
|
||||
| Device | RTF | TTFB | Batch Size | Backend |
|
||||
| Device | RTF | TTFT | Batch Size | Backend |
|
||||
| :---------: | :---: | :----: | :--------: | :-------------------------: |
|
||||
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
@ -136,13 +136,13 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
conda activate GPTSoVits
|
||||
conda install ffmpeg=7 -c conda-forge
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
##### Ubuntu/Debian 用户
|
||||
|
||||
```bash
|
||||
sudo apt install ffmpeg=7
|
||||
sudo apt install ffmpeg
|
||||
sudo apt install libsox-dev
|
||||
```
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@
|
||||
|
||||
## 推論速度
|
||||
|
||||
| Device | RTF | TTFB | Batch Size | Backend |
|
||||
| Device | RTF | TTFT | Batch Size | Backend |
|
||||
| :---------: | :---: | :----: | :--------: | :-------------------------: |
|
||||
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
@ -126,13 +126,13 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
conda activate GPTSoVits
|
||||
conda install ffmpeg=7 -c conda-forge
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
##### Ubuntu/Debian ユーザー
|
||||
|
||||
```bash
|
||||
sudo apt install ffmpeg=7
|
||||
sudo apt install ffmpeg
|
||||
sudo apt install libsox-dev
|
||||
```
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@
|
||||
|
||||
## 추론 속도
|
||||
|
||||
| Device | RTF | TTFB | Batch Size | Backend |
|
||||
| Device | RTF | TTFT | Batch Size | Backend |
|
||||
| :---------: | :---: | :----: | :--------: | :-------------------------: |
|
||||
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
@ -132,13 +132,13 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
conda activate GPTSoVits
|
||||
conda install ffmpeg=7 -c conda-forge
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
##### Ubuntu/Debian 사용자
|
||||
|
||||
```bash
|
||||
sudo apt install ffmpeg=7
|
||||
sudo apt install ffmpeg
|
||||
sudo apt install libsox-dev
|
||||
```
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ Görünmeyen konuşmacılar birkaç örnekli ince ayar demosu:
|
||||
|
||||
## çıkarım hızı
|
||||
|
||||
| Device | RTF | TTFB | Batch Size | Backend |
|
||||
| Device | RTF | TTFT | Batch Size | Backend |
|
||||
| :---------: | :---: | :----: | :--------: | :-------------------------: |
|
||||
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
@ -132,13 +132,13 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
conda activate GPTSoVits
|
||||
conda install ffmpeg=7 -c conda-forge
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
##### Ubuntu/Debian Kullanıcıları
|
||||
|
||||
```bash
|
||||
sudo apt install ffmpeg=7
|
||||
sudo apt install ffmpeg
|
||||
sudo apt install libsox-dev
|
||||
```
|
||||
|
||||
|
||||
@ -227,7 +227,7 @@ else
|
||||
fi
|
||||
|
||||
echo -e "${INFO}Installing FFmpeg & CMake..."
|
||||
run_conda_quiet ffmpeg=7 cmake make
|
||||
run_conda_quiet ffmpeg cmake make
|
||||
echo -e "${SUCCESS}FFmpeg & CMake Installed"
|
||||
|
||||
echo -e "${INFO}Installing unzip..."
|
||||
|
||||
10
test.py
10
test.py
@ -552,7 +552,7 @@ def get_tts_wav(
|
||||
):
|
||||
torch.set_grad_enabled(False)
|
||||
debug = os.getenv("DEBUG") == "1"
|
||||
ttfb_time = ttime()
|
||||
ttft_time = ttime()
|
||||
|
||||
if ref_wav_path:
|
||||
pass
|
||||
@ -698,7 +698,7 @@ def get_tts_wav(
|
||||
)[0][0] # type: ignore
|
||||
|
||||
if i_text == 0:
|
||||
ttfb_time = ttime() - ttfb_time
|
||||
ttft_time = ttime() - ttft_time
|
||||
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
|
||||
if max_audio > 1:
|
||||
audio = audio / max_audio
|
||||
@ -729,10 +729,10 @@ def get_tts_wav(
|
||||
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
|
||||
console.print(f">> RTF: {rtf_value:.2f}")
|
||||
|
||||
if ttfb_time > 2:
|
||||
console.print(f">> TTFB: {ttfb_time:.3f} s")
|
||||
if ttft_time > 2:
|
||||
console.print(f">> TTFT: {ttft_time:.3f} s")
|
||||
else:
|
||||
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
|
||||
console.print(f">> TTFT: {ttft_time * 1000:.3f} ms")
|
||||
|
||||
yield opt_sr, (audio_opt_n * 32767).astype(np.int16)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user