This commit is contained in:
XXXXRT666 2025-10-23 01:54:40 +01:00
parent 0ec73aac3c
commit 5da25d4f18
16 changed files with 313 additions and 264 deletions

View File

@ -1,4 +1,5 @@
import importlib.util
import os
import torch
@ -25,8 +26,12 @@ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
cpu_count = os.cpu_count() or 1
torch.set_num_threads(cpu_count)
torch.set_num_interop_threads(cpu_count)
backends = ["torch_varlen"]
if torch.cuda.is_available():
if torch.cuda.is_available() and torch.version.cuda is not None:
backends.append("torch_static_cuda_graph")
# if importlib.util.find_spec("sageattention") is not None:
# for i in range(torch.cuda.device_count()):
@ -44,7 +49,7 @@ if torch.cuda.is_available():
# backends.append("mps_flash_attn_varlen")
BLACKWELL = False
if torch.cuda.is_available():
if torch.cuda.is_available() and torch.version.cuda is not None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
sm_version = major + minor / 10.0

View File

@ -12,6 +12,7 @@ from ..structs import T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
CUDAGraphStateABC,
FeedForward,
KVCacheNHD,
KVCacheProtocol,
@ -114,6 +115,8 @@ class T2SDecoder(T2SDecoderABC):
self.kv_class = KVCacheNHD
self.graph_cache_class = CUDAGraphCache
def compile(self, *args, **kwds):
pass
@ -124,40 +127,39 @@ class T2SDecoder(T2SDecoderABC):
return super().pre_forward(session)
class CUDAGraphCache(CUDAGraphCacheABC):
class CUDAGraphState(CUDAGraphStateABC):
applicable: bool = True
def __init__(
self,
decoder: T2SDecoder,
bsz: int,
decoder: T2SDecoderABC,
) -> None:
self.is_applicable = True
super().__init__(decoder)
super().__init__(bsz, decoder)
def release_graph(self, session: T2SSession):
if session.id == self.id:
self.assigned = False
else:
assert session.graph
session.graph.reset()
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
def capture(self):
graph = self.decoder.capture(
self.input_pos,
self.xy_pos,
self.xy_dec,
self.kv_cache,
)
self.graph = graph
self.stream = torch.cuda.Stream()
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
class CUDAGraphCache(CUDAGraphCacheABC):
is_applicable = True
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def __init__(
self,
decoder,
cache_size: int = 5,
) -> None:
super().__init__(decoder, cache_size)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore
def create_graph_cache(self, bsz: int):
for _ in range(self.cache_size):
state = CUDAGraphState(bsz, self.decoder)
state.capture()
self.graph_cache[bsz].put(state)

View File

@ -6,6 +6,7 @@ from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
CUDAGraphStateABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
@ -91,6 +92,8 @@ class T2SDecoder(T2SDecoderABC):
self.kv_class = KVCacheHND
self.graph_cache_class = CUDAGraphCache
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
@ -111,57 +114,54 @@ class T2SDecoder(T2SDecoderABC):
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphState(CUDAGraphStateABC):
applicable: bool = False
def __init__(
self,
bsz: int,
decoder: T2SDecoderABC,
) -> None:
self.attn_mask: Tensor = (
torch.randint(
0,
2,
(bsz, decoder.n_head, 1, decoder.max_seq_length),
)
.bool()
.to(decoder.device)
)
super().__init__(bsz, decoder)
def capture(self):
graph = self.decoder.capture(
self.input_pos,
self.xy_pos,
self.xy_dec,
self.kv_cache,
attn_mask=self.attn_mask,
)
self.graph = graph
self.stream = torch.cuda.Stream()
def assign_graph(self, session: T2SSession):
session.attn_mask = self.attn_mask
return super().assign_graph(session)
class CUDAGraphCache(CUDAGraphCacheABC):
is_applicable = True
def __init__(
self,
decoder,
cache_size: int = 5,
) -> None:
self.is_applicable = True
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
super().__init__(decoder, cache_size)
def release_graph(self, session: T2SSession):
if session.id == self.id:
self.assigned = False
else:
assert session.graph
session.graph.reset()
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore
def create_graph_cache(self, bsz: int):
for _ in range(self.cache_size):
state = CUDAGraphState(bsz, self.decoder)
state.capture()
self.graph_cache[bsz].put(state)

View File

@ -99,6 +99,8 @@ class T2SDecoder(T2SDecoderABC):
self.kv_class = KVCacheHNDVarlen
self.graph_cache_class = CUDAGraphCache
def capture(
self,
*args,
@ -127,18 +129,13 @@ class T2SDecoder(T2SDecoderABC):
class CUDAGraphCache(CUDAGraphCacheABC):
is_applicable = False
def __init__(
self,
decoder,
) -> None:
self.is_applicable = False
super().__init__(decoder)
def release_graph(self, session: T2SSession):
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
def get_cache_graph(self, session: T2SSession):
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
def capture_new_graph(self, session: T2SSession):
def create_graph_cache(self, bsz: int) -> NoReturn:
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")

View File

@ -1,6 +1,6 @@
import contextlib
import gc
import os
import sys
import time
import traceback
from importlib import import_module
@ -34,7 +34,7 @@ class T2SEngine(T2SEngineProtocol):
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
# self.decoder_model.compile()
self.graphcache: CUDAGraphCacheABC = self.init_cache()
self.graphcache: CUDAGraphCacheABC = decoder_model.graph_cache_class(self.decoder_model)
def _handle_request(self, request: T2SRequest):
with self.device:
@ -47,6 +47,7 @@ class T2SEngine(T2SEngineProtocol):
infer_speed = 0.0
infer_time = 0.0
idx = 0
graph_state = None
torch_profiler = TorchProfiler(debug)
with (
@ -64,105 +65,132 @@ class T2SEngine(T2SEngineProtocol):
max_token = min(int(1500 - session.input_pos.max()), 1000) * session.bsz
task = progress.add_task("T2S Decoding", total=max_token)
for idx in range(max_token):
progress.update(task, advance=session.bsz)
if idx == 0:
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
session.kv_cache = decoder.init_cache(session.bsz)
t1 = time.perf_counter()
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
else:
if (
request.use_cuda_graph
and session.graph is None
and self.graphcache.is_applicable
and torch.cuda.is_available()
):
self.graphcache.assign_graph(session)
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
if session.graph:
assert session.stream
session.stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(session.stream):
session.xy_pos_.copy_(session.xy_pos)
session.graph.replay()
xy_dec = session.xy_dec_.clone()
else:
args, kwds = decoder.pre_forward(session)
xy_dec = decoder.h(
session.input_pos,
session.xy_pos,
session.kv_cache,
*args,
**kwds,
)
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
try:
for idx in range(max_token):
progress.update(task, advance=session.bsz)
if idx == 0:
logits[:, -1] = float("-inf")
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
session.kv_cache = decoder.init_cache(session.bsz)
t1 = time.perf_counter()
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
else:
if (
idx == 1
and request.use_cuda_graph
and self.graphcache.is_applicable
and torch.cuda.is_available()
and torch.version.cuda is not None
and os.environ.get("CUDAGraph", "1") != "0"
):
graph_state = self.graphcache[session.bsz].assign_graph(session)
with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
samples = session.sample(
logits=logits,
previous_tokens=session.y[:, : session.y_len + idx],
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
session.input_pos.add_(1)
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
argmax_token = torch.argmax(logits, dim=-1)
sample_token = samples.squeeze(1)
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
newly_done_mask = EOS_mask & (~session.completed)
newly_done_indices = newly_done_mask.nonzero()
if newly_done_indices.numel() > 0:
for i in newly_done_indices:
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx].squeeze(0)
session.completed[newly_done_indices] = True
if torch.all(session.completed).item():
logger.info(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
)
logger.info(
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
)
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) * session.bsz / infer_time
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
for i in range(session.bsz):
if not session.completed[i].item():
session.y_results[i] = session.y[[i], session.y_len : session.y_len + idx].squeeze(
0
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
if session.graph:
assert session.stream
session.stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(session.stream):
session.xy_pos_.copy_(session.xy_pos)
session.graph.replay()
xy_dec = session.xy_dec_.clone()
else:
args, kwds = decoder.pre_forward(session)
xy_dec = decoder.h(
session.input_pos,
session.xy_pos,
session.kv_cache,
*args,
**kwds,
)
session.completed[i] = True
logger.error("Bad Full Prediction")
with (
torch.cuda.stream(session.stream)
if session.stream is not None
else contextlib.nullcontext()
):
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
if idx == 0:
logits[:, -1] = float("-inf")
with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
samples = session.sample(
logits=logits,
previous_tokens=session.y[:, : session.y_len + idx],
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
session.input_pos.add_(1)
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
argmax_token = torch.argmax(logits, dim=-1)
sample_token = samples.squeeze(1)
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
newly_done_mask = EOS_mask & (~session.completed)
newly_done_indices = newly_done_mask.nonzero()
if newly_done_indices.numel() > 0:
for i in newly_done_indices:
session.y_results[i] = session.y[
i, session.y_len : session.y_len + idx
].squeeze(0)
session.completed[newly_done_indices] = True
if torch.all(session.completed).item():
logger.info(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
)
logger.info(
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
)
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) * session.bsz / infer_time
break
break
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
for i in range(session.bsz):
if not session.completed[i].item():
session.y_results[i] = session.y[
[i], session.y_len : session.y_len + idx
].squeeze(0)
session.completed[i] = True
logger.error("Bad Full Prediction")
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) * session.bsz / infer_time
break
if idx == 10:
torch_profiler.end()
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
if request.use_cuda_graph and self.graphcache.is_applicable:
self.graphcache.release_graph(session)
if idx == 10:
torch_profiler.end()
finally:
if (
request.use_cuda_graph
and self.graphcache.is_applicable
and torch.cuda.is_available()
and torch.version.cuda is not None
and os.environ.get("CUDAGraph", "1") != "0"
):
self.graphcache.release_graph(graph_state)
match decoder.device.type:
case "cuda":
torch.cuda.empty_cache()
case "mps":
torch.mps.empty_cache()
case "xpu":
torch.xpu.empty_cache()
case "mtia":
torch.mtia.empty_cache()
case "cpu":
gc.collect(1)
return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz
@ -205,14 +233,3 @@ class T2SEngine(T2SEngineProtocol):
logger.info(f"Quantized by {quantize_mode} Quantization")
return decoder.eval()
def init_cache(self):
assert self.decoder_model
module_name = self.decoder_model.__class__.__module__
module = sys.modules.get(module_name)
assert module
target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
return target_class(self.decoder_model)

View File

@ -8,11 +8,11 @@ import math
import os
import pickle
import platform
import random
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from pathlib import Path
from queue import Queue
from typing import Literal, MutableSequence
import torch
@ -479,6 +479,8 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
max_seq_length=max_seq_length,
)
self.graph_cache_class: type[CUDAGraphCacheABC]
self.bits: int
self.group_size: int
@ -639,56 +641,78 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
raise ValueError(f"Unsupported Quantization Mode for PyTorch: {mode}")
class CUDAGraphCacheABC(ABC):
class CUDAGraphStateABC(ABC):
def __init__(
self,
bsz: int,
decoder: T2SDecoderABC,
) -> None:
self.is_applicable: bool
self.bsz = bsz
self.embedding_dim = decoder.embedding_dim
self.dtype = decoder.bert_proj.bias.dtype
self.device = decoder.device
if torch.cuda.is_available() and self.is_applicable:
self.device: torch.device = decoder.device
self.dtype = decoder.bert_proj.bias.dtype
self.decoder: T2SDecoderABC = decoder
self.graph: torch.cuda.CUDAGraph | None = None
self.stream: torch.cuda.Stream | None = None
self.assigned: bool = False
self.xy_pos = torch.rand(size=(self.bsz, 1, self.embedding_dim), device=self.device).to(self.dtype)
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(bsz)
self.xy_dec = self.xy_pos.clone()
self.input_pos = torch.tensor([10] * self.bsz, device=self.device).to(torch.int32)
self.decoder: T2SDecoderABC = decoder
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
self.dtype
)
self.xy_dec = self.xy_pos.clone()
self.capture()
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
self.graph: torch.cuda.CUDAGraph | None = None
self.stream: torch.cuda.Stream | None
@abstractmethod
def capture(self): ...
self.id: int = random.randint(1, 2**32 - 1)
def assign_graph(self, session: T2SSession) -> CUDAGraphStateABC:
assert self.graph
session.graph = self.graph
session.stream = self.stream
def assign_graph(self, session: T2SSession):
if self.graph is None:
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
self.graph = graph
self.stream = torch.cuda.Stream()
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
if self.assigned is False:
self.get_cache_graph(session)
session.id = self.id
self.assigned = True
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
return self
class CUDAGraphCacheABC(ABC):
is_applicable: bool
def __init__(self, decoder: T2SDecoderABC, cache_size: int = 5) -> None:
self.decoder = decoder
self.max_batch_size = decoder.max_batch_size
self.cache_size = cache_size
self.graph_cache: dict[int, Queue[CUDAGraphStateABC]] = {}
if torch.cuda.is_available() and torch.version.cuda is not None and os.environ.get("CUDAGraph", "1") != "0":
self.create_graph_cache(1)
def __getitem__(self, bsz: int) -> CUDAGraphStateABC:
if self.is_applicable:
assert bsz <= self.max_batch_size
if self.graph_cache.get(bsz) is None:
self.create_graph_cache(bsz)
return self.graph_cache[bsz].get()
else:
self.capture_new_graph(session)
raise RuntimeError("CUDAGraph Is Not Applicable")
@abstractmethod
def release_graph(self, session: T2SSession): ...
def create_graph_cache(self, bsz: int): ...
@abstractmethod
def get_cache_graph(self, session: T2SSession):
pass
@abstractmethod
def capture_new_graph(self, session: T2SSession):
pass
def release_graph(self, graph_state: CUDAGraphStateABC | None):
if graph_state is None:
return
bsz = graph_state.bsz
assert bsz <= self.max_batch_size
assert self.graph_cache.get(bsz) is not None
self.graph_cache[bsz].put(graph_state)
class TorchProfiler:

View File

@ -955,7 +955,7 @@ class TTS:
"""
########## variables initialization ###########
torch.set_grad_enabled(False)
ttfb_time = time.perf_counter()
ttft_time = time.perf_counter()
self.stop_flag: bool = False
text: str = inputs.get("text", "")
text_lang: str = inputs.get("text_lang", "")
@ -1141,7 +1141,9 @@ class TTS:
temperature=temperature,
repetition_penalty=repetition_penalty,
debug=os.environ.get("DEBUG", "0") == "1",
use_cuda_graph=torch.cuda.is_available(),
use_cuda_graph=torch.cuda.is_available()
and torch.version.cuda is not None
and os.environ.get("CUDAGraph", "1") != "0",
)
t2s_result = self.t2s_model.generate(t2s_request)
@ -1151,7 +1153,7 @@ class TTS:
pred_semantic_list = t2s_result.result
assert pred_semantic_list
pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list]
pred_semantic_list = [semantic.squeeze(0).to(self.configs.device) for semantic in pred_semantic_list]
infer_len.append(t2s_result.total_tokens)
infer_time.append(t2s_result.infer_speed[-1])
@ -1243,7 +1245,7 @@ class TTS:
)
batch_audio_fragment.append(audio_fragment)
if idx == 0:
ttfb_time = time.perf_counter() - ttfb_time
ttft_time = time.perf_counter() - ttft_time
t5 = time.perf_counter()
t_45 += t5 - t4
if return_fragment:
@ -1307,10 +1309,10 @@ class TTS:
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
console.print(f">> RTF: {rtf_value:.2f}")
if ttfb_time > 2:
console.print(f">> TTFB: {ttfb_time:.3f} s")
if ttft_time > 2:
console.print(f">> TTFT: {ttft_time:.3f} s")
else:
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
console.print(f">> TTFT: {ttft_time * 1000:.3f} ms")
self.empty_cache()

View File

@ -708,7 +708,7 @@ async def get_tts_wav(
torch.set_grad_enabled(False)
progress(0, desc="Inferencing...")
debug = os.getenv("DEBUG") == "1"
ttfb_time = ttime()
ttft_time = ttime()
if ref_wav_path:
pass
@ -829,7 +829,9 @@ async def get_tts_wav(
top_p=top_p,
temperature=temperature,
early_stop_num=1500,
use_cuda_graph=torch.cuda.is_available(), # Try to use CUDA Graph for all backend, fallback to normal if not applicapble
use_cuda_graph=torch.cuda.is_available()
and torch.version.cuda is not None
and os.environ.get("CUDAGraph", "1") != "0",
debug=debug,
)
assert t2s_engine
@ -938,7 +940,7 @@ async def get_tts_wav(
wav_gen = vocoder_model(cfm_res) # type: ignore
audio = wav_gen[0][0]
if i_text == 0:
ttfb_time = ttime() - ttfb_time
ttft_time = ttime() - ttft_time
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
if max_audio > 1:
audio = audio / max_audio
@ -980,12 +982,12 @@ async def get_tts_wav(
gr.Info(f"{infer_speed_avg:.2f} Token/s", title="Infer Speed")
gr.Info(f"{rtf_value:.2f}", title="RTF")
if ttfb_time > 2:
console.print(f">> TTFB: {ttfb_time:.3f} s")
gr.Info(f"{ttfb_time:.3f} s", title="TTFB")
if ttft_time > 2:
console.print(f">> TTFT: {ttft_time:.3f} s")
gr.Info(f"{ttft_time:.3f} s", title="TTFT")
else:
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
gr.Info(f"{ttfb_time * 1000:.3f} ms", title="TTFB")
console.print(f">> TTFT: {ttft_time * 1000:.3f} ms")
gr.Info(f"{ttft_time * 1000:.3f} ms", title="TTFT")
progress(1, desc="Done")
yield opt_sr, (audio_opt_n * 32767).astype(np.int16)

View File

@ -958,7 +958,7 @@ class SynthesizerTrn(nn.Module):
ge = self.prelu(ge)
return ge
if type(refer) == list:
if isinstance(refer, list):
ges = []
for idx, _refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)

View File

@ -51,7 +51,7 @@ Unseen speakers few-shot fine-tuning demo:
## Infer Speed
| Device | RTF | TTFB | Batch Size | Backend |
| Device | RTF | TTFT | Batch Size | Backend |
| :---------: | :---: | :----: | :--------: | :-------------------------: |
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
@ -138,13 +138,13 @@ pip install -r requirements.txt
```bash
conda activate GPTSoVits
conda install ffmpeg=7 -c conda-forge
conda install ffmpeg -c conda-forge
```
##### Ubuntu/Debian Users
```bash
sudo apt install ffmpeg=7
sudo apt install ffmpeg
sudo apt install libsox-dev
```

View File

@ -51,7 +51,7 @@
## 推理速度
| Device | RTF | TTFB | Batch Size | Backend |
| Device | RTF | TTFT | Batch Size | Backend |
| :---------: | :---: | :----: | :--------: | :-------------------------: |
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
@ -136,13 +136,13 @@ pip install -r requirements.txt
```bash
conda activate GPTSoVits
conda install ffmpeg=7 -c conda-forge
conda install ffmpeg -c conda-forge
```
##### Ubuntu/Debian 用户
```bash
sudo apt install ffmpeg=7
sudo apt install ffmpeg
sudo apt install libsox-dev
```

View File

@ -51,7 +51,7 @@
## 推論速度
| Device | RTF | TTFB | Batch Size | Backend |
| Device | RTF | TTFT | Batch Size | Backend |
| :---------: | :---: | :----: | :--------: | :-------------------------: |
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
@ -126,13 +126,13 @@ pip install -r requirements.txt
```bash
conda activate GPTSoVits
conda install ffmpeg=7 -c conda-forge
conda install ffmpeg -c conda-forge
```
##### Ubuntu/Debian ユーザー
```bash
sudo apt install ffmpeg=7
sudo apt install ffmpeg
sudo apt install libsox-dev
```

View File

@ -51,7 +51,7 @@
## 추론 속도
| Device | RTF | TTFB | Batch Size | Backend |
| Device | RTF | TTFT | Batch Size | Backend |
| :---------: | :---: | :----: | :--------: | :-------------------------: |
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
@ -132,13 +132,13 @@ pip install -r requirements.txt
```bash
conda activate GPTSoVits
conda install ffmpeg=7 -c conda-forge
conda install ffmpeg -c conda-forge
```
##### Ubuntu/Debian 사용자
```bash
sudo apt install ffmpeg=7
sudo apt install ffmpeg
sudo apt install libsox-dev
```

View File

@ -51,7 +51,7 @@ Görünmeyen konuşmacılar birkaç örnekli ince ayar demosu:
## çıkarım hızı
| Device | RTF | TTFB | Batch Size | Backend |
| Device | RTF | TTFT | Batch Size | Backend |
| :---------: | :---: | :----: | :--------: | :-------------------------: |
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
@ -132,13 +132,13 @@ pip install -r requirements.txt
```bash
conda activate GPTSoVits
conda install ffmpeg=7 -c conda-forge
conda install ffmpeg -c conda-forge
```
##### Ubuntu/Debian Kullanıcıları
```bash
sudo apt install ffmpeg=7
sudo apt install ffmpeg
sudo apt install libsox-dev
```

View File

@ -227,7 +227,7 @@ else
fi
echo -e "${INFO}Installing FFmpeg & CMake..."
run_conda_quiet ffmpeg=7 cmake make
run_conda_quiet ffmpeg cmake make
echo -e "${SUCCESS}FFmpeg & CMake Installed"
echo -e "${INFO}Installing unzip..."

10
test.py
View File

@ -552,7 +552,7 @@ def get_tts_wav(
):
torch.set_grad_enabled(False)
debug = os.getenv("DEBUG") == "1"
ttfb_time = ttime()
ttft_time = ttime()
if ref_wav_path:
pass
@ -698,7 +698,7 @@ def get_tts_wav(
)[0][0] # type: ignore
if i_text == 0:
ttfb_time = ttime() - ttfb_time
ttft_time = ttime() - ttft_time
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
if max_audio > 1:
audio = audio / max_audio
@ -729,10 +729,10 @@ def get_tts_wav(
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
console.print(f">> RTF: {rtf_value:.2f}")
if ttfb_time > 2:
console.print(f">> TTFB: {ttfb_time:.3f} s")
if ttft_time > 2:
console.print(f">> TTFT: {ttft_time:.3f} s")
else:
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
console.print(f">> TTFT: {ttft_time * 1000:.3f} ms")
yield opt_sr, (audio_opt_n * 32767).astype(np.int16)