This commit is contained in:
XXXXRT666 2025-10-24 05:24:25 +01:00
parent 99fba5154e
commit 8a2448fa3c
3 changed files with 19 additions and 14 deletions

View File

@ -51,6 +51,8 @@ class T2SEngine(T2SEngineProtocol):
infer_time = 0.0
idx = 0
graph_state = None
bsz = session.bsz
result: list[torch.Tensor] = []
torch_profiler = TorchProfiler(debug)
with (
@ -65,15 +67,15 @@ class T2SEngine(T2SEngineProtocol):
) as progress,
):
torch_profiler.start()
max_token = min(int(1500 - session.input_pos.max()), 1000) * session.bsz
max_token = min(int(1500 - session.input_pos.max()), 1000) * bsz
task = progress.add_task("T2S Decoding", total=max_token)
try:
for idx in range(max_token):
progress.update(task, advance=session.bsz)
progress.update(task, advance=bsz)
if idx == 0:
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
session.kv_cache = decoder.init_cache(session.bsz)
session.kv_cache = decoder.init_cache(bsz)
t1 = time.perf_counter()
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
@ -86,7 +88,7 @@ class T2SEngine(T2SEngineProtocol):
and torch.version.cuda is not None
and os.environ.get("CUDAGraph", "1") != "0"
):
graph_state = self.graphcache[session.bsz].assign_graph(session)
graph_state = self.graphcache[bsz].assign_graph(session)
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
if session.graph:
@ -148,15 +150,13 @@ class T2SEngine(T2SEngineProtocol):
logger.info(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
)
logger.info(
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
)
logger.info(f"Infer Speed: {(idx + 1) * bsz / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) * session.bsz / infer_time
infer_speed = (idx + 1) * bsz / infer_time
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
for i in range(session.bsz):
for i in range(bsz):
if not session.completed[i].item():
session.y_results[i] = session.y[
[i], session.y_len : session.y_len + idx
@ -164,7 +164,7 @@ class T2SEngine(T2SEngineProtocol):
session.completed[i] = True
logger.error("Bad Full Prediction")
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) * session.bsz / infer_time
infer_speed = (idx + 1) * bsz / infer_time
break
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
@ -178,7 +178,13 @@ class T2SEngine(T2SEngineProtocol):
timer.summary()
timer.clear()
result = session.y_results[: request.valid_length]
return result, infer_speed, infer_time, (idx + 1) * bsz
finally:
del session
if (
request.use_cuda_graph
and self.graphcache.is_applicable
@ -200,8 +206,6 @@ class T2SEngine(T2SEngineProtocol):
case "cpu":
gc.collect(1)
return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz
def generate(self, request: T2SRequest):
try:
result, infer_speed, infer_time, total_tokens = self._handle_request(request)

View File

@ -964,6 +964,7 @@ class TTS:
"""
########## variables initialization ###########
torch.set_grad_enabled(False)
self.empty_cache()
ttft_time = time.perf_counter()
self.stop_flag: bool = False
text: str = inputs.get("text", "")
@ -1334,7 +1335,7 @@ class TTS:
gc.collect()
if self.configs.device.type == "cuda":
logger.info(str(torch.cuda.memory_allocated(self.configs.device.index) / 1024**3) + "GB")
logger.info("Curr: " + str(torch.cuda.memory_allocated(self.configs.device.index) / 1024**3) + " GB")
elif self.configs.device.type == "mps":
logger.info("Curr: " + str(torch.mps.current_allocated_memory() / 1024**3) + " GB")
logger.info("Driver: " + str(torch.mps.driver_allocated_memory() / 1024**3) + " GB")

View File

@ -54,7 +54,7 @@ Unseen speakers few-shot fine-tuning demo:
| Device | RTF | TTFT | Batch Size | Backend |
| :---------: | :----: | :----: | :--------: | :-------------------------: |
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 5090 | 0.0122 | UNK | 30 | Flash Attn Varlen CUDAGraph |
| RTX 5090 | 0.0109 | UNK | 40 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
| Apple M4 | 0.16 | 1363ms | 1 | MLX Varlen |