Refactor the T2S scheduler and inference path to cut attention-mask memory: make decode_attn_mask optional on T2SRunningRequest, carry a key_padding_mask on T2SActiveBatch, build prefill masks via broadcasting instead of full .repeat() copies, and strip the prompt prefix from finished requests' semantic tokens. Add benchmarking tools for API concurrency/latency and a per-stage T2S GPU memory breakdown.

baicai-1145 2026-03-09 01:42:04 +08:00
parent dc37b0b9ef
commit d245eb169c
4 changed files with 1192 additions and 28 deletions

View File

@@ -1781,6 +1781,7 @@ class TTS:
return audio
@torch.inference_mode()
def synthesize_audio_request_local(
self,
semantic_tokens: torch.Tensor,

View File

@@ -70,7 +70,7 @@ class T2SRunningRequest:
state: T2SRequestState
y_sequence: torch.LongTensor
prefix_len: int
decode_attn_mask: torch.Tensor
decode_attn_mask: Optional[torch.Tensor]
k_cache: List[torch.Tensor]
v_cache: List[torch.Tensor]
step_idx: int
@@ -93,6 +93,7 @@ class T2SActiveBatch:
y_sequences: List[torch.LongTensor]
prefix_lens: torch.LongTensor
xy_pos: torch.Tensor
key_padding_mask: torch.Tensor
prefill_attn_mask: torch.Tensor
decode_attn_mask: Optional[torch.Tensor]
k_cache: Optional[List[torch.Tensor]]
@@ -110,6 +111,7 @@ def normalize_sentence(text: str, language: str) -> str:
return text
@torch.inference_mode()
def prepare_request_state(
tts: Any,
spec: SchedulerRequestSpec,
@@ -212,6 +214,7 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device:
)
@torch.inference_mode()
def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
x_items: List[torch.Tensor] = []
y_pos_items: List[torch.Tensor] = []
@@ -240,22 +243,19 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
device = x_batch.device
x_lens_tensor = torch.LongTensor(x_lens).to(device)
prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device)
batch_size = len(states)
src_len = max_x_len + max_prefix_len
x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len)
y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len)
padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1)
key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool()
x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True)
y_mask = F.pad(
torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1),
(max_x_len, 0),
value=False,
)
causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1)
padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1)
attn_mask = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool()
causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0)
attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1)
return T2SActiveBatch(
request_ids=[state.request_id for state in states],
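Note: the rewritten mask construction above relies on broadcasting rather than materializing per-batch, per-head copies. A minimal sketch of the composition, assuming make_pad_mask_left marks padded key slots as True:

import torch

B, S = 4, 6
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)  # True = padded key slot
key_padding_mask[0, :2] = True  # e.g. a left-padded prompt

# one (1, S, S) causal mask shared by the whole batch
causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1).unsqueeze(0)

# (1, S, S) | (B, 1, S) broadcasts to (B, S, S) with no .repeat() copies;
# the final unsqueeze(1) leaves the head dimension to broadcast as well
attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1)
print(attn_mask.shape)  # torch.Size([4, 1, 6, 6])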
@@ -265,6 +265,7 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
y_sequences=y_sequences,
prefix_lens=prefix_lens_tensor,
xy_pos=xy_pos,
key_padding_mask=key_padding_mask,
prefill_attn_mask=attn_mask,
decode_attn_mask=None,
k_cache=None,
@@ -322,10 +323,11 @@ def _sample_per_request(
finish_reason = "eos_argmax"
if finish_reason is not None:
prefix_len = int(active_batch.prefix_lens[batch_index].item())
finished_items.append(
T2SFinishedItem(
request_id=state.request_id,
semantic_tokens=new_history[:-1].clone(),
semantic_tokens=new_history[prefix_len:-1].clone(),
finish_idx=step_idx,
finish_reason=finish_reason,
)
@@ -346,11 +348,7 @@ def decode_one_step(
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt(
active_batch.xy_pos, active_batch.prefill_attn_mask, None
)
active_batch.decode_attn_mask = F.pad(
active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2),
(0, 1),
value=False,
)
active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
active_batch.prefill_done = True
else:
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token(
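During decode there is a single query token that may attend to every cached key except padding, so the per-step mask above is derived straight from key_padding_mask, with one extra always-attendable slot appended. A sketch of the shape evolution, under the same True = blocked convention:

import torch
import torch.nn.functional as F

B, S = 2, 5
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)
key_padding_mask[1, :3] = True  # request 1 was left-padded

# (B, S) -> (B, 1, 1, S), then append one False slot so the token produced
# by prefill can be attended on the first decode step
decode_attn_mask = F.pad(key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
print(decode_attn_mask.shape)  # torch.Size([2, 1, 1, 6])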
@@ -411,6 +409,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
return F.pad(mask, (pad_len, 0), value=True)
def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor:
if running_request.decode_attn_mask is not None:
return running_request.decode_attn_mask
current_mask_len = running_request.k_cache[0].shape[1] + 1
return torch.zeros(
(1, 1, 1, current_mask_len),
dtype=torch.bool,
device=running_request.k_cache[0].device,
)
@torch.inference_mode()
def run_prefill_step(
model: Any,
states: Sequence[T2SRequestState],
@@ -421,11 +431,9 @@ def run_prefill_step(
active_batch = build_prefill_batch(model, states)
xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None)
decode_attn_mask = F.pad(
active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2),
(0, 1),
value=False,
)
decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
if len(states) == 1 and not decode_attn_mask.any().item():
decode_attn_mask = None
logits = model.ar_predict_layer(xy_dec[:, -1])
running_requests: List[T2SRunningRequest] = []
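The elision above is safe because a single unpadded request yields an all-False (nothing blocked) mask, and attn_mask=None lets the attention kernel skip mask application entirely. A quick equivalence check, assuming the transformer inverts the True = blocked mask before calling SDPA (PyTorch's boolean attn_mask uses True = attend):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 8, 33, 64)
v = torch.randn(1, 8, 33, 64)
blocked = torch.zeros(1, 1, 1, 33, dtype=torch.bool)  # all False: nothing blocked

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=~blocked)
out_none = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
torch.testing.assert_close(out_masked, out_none)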
@@ -463,7 +471,7 @@ def run_prefill_step(
finished_items.append(
T2SFinishedItem(
request_id=state.request_id,
semantic_tokens=new_history[:-1].clone(),
semantic_tokens=new_history[prefix_len:-1].clone(),
finish_idx=0,
finish_reason=finish_reason,
)
@@ -479,7 +487,11 @@ def run_prefill_step(
state=state,
y_sequence=new_history,
prefix_len=prefix_len,
decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(),
decode_attn_mask=(
None
if decode_attn_mask is None
else decode_attn_mask[batch_index : batch_index + 1].clone()
),
k_cache=request_k_cache,
v_cache=request_v_cache,
step_idx=1,
@@ -492,10 +504,9 @@ def run_prefill_step(
def _build_decode_batch_from_running(
model: Any,
running_requests: Sequence[T2SRunningRequest],
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]:
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]:
xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests])
max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests)
max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests)
num_layers = len(running_requests[0].k_cache)
batched_k_cache: List[torch.Tensor] = []
@@ -508,13 +519,19 @@ def _build_decode_batch_from_running(
torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0)
)
batched_decode_attn_mask = torch.cat(
[_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests],
dim=0,
)
if all(item.decode_attn_mask is None for item in running_requests):
batched_decode_attn_mask = None
else:
materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests]
max_mask_len = max(mask.shape[-1] for mask in materialized_masks)
batched_decode_attn_mask = torch.cat(
[_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks],
dim=0,
)
return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask
@torch.inference_mode()
def run_decode_step_for_running(
model: Any,
running_requests: Sequence[T2SRunningRequest],
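Because masks may now be None per request, batching first materializes a concrete mask for each request (via _materialize_decode_mask_for_request) and only then left-pads caches and masks to a common width. A toy sketch of the left alignment, with assumed (1, kv_len, dim) cache and (1, 1, 1, kv_len) mask shapes:

import torch
import torch.nn.functional as F

def pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor:
    # zero-fill on the left so the newest token stays at the right edge
    return F.pad(cache, (0, 0, target_len - cache.shape[1], 0))

def pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
    # padded slots are True (= blocked), matching _pad_decode_mask_left
    return F.pad(mask, (target_len - mask.shape[-1], 0), value=True)

short_k, long_k = torch.randn(1, 5, 8), torch.randn(1, 9, 8)
batched_k = torch.cat([pad_cache_left(short_k, 9), long_k], dim=0)  # (2, 9, 8)
batched_m = pad_decode_mask_left(torch.zeros(1, 1, 1, 5, dtype=torch.bool), 9)
print(batched_k.shape, batched_m[0, 0, 0])  # first four slots blocked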
@@ -568,7 +585,7 @@ def run_decode_step_for_running(
finished_items.append(
T2SFinishedItem(
request_id=running_request.state.request_id,
semantic_tokens=new_history[:-1].clone(),
semantic_tokens=new_history[running_request.prefix_len:-1].clone(),
finish_idx=current_idx,
finish_reason=finish_reason,
)
@@ -578,12 +595,20 @@ def run_decode_step_for_running(
real_next_kv_len = running_request.k_cache[0].shape[1] + 1
request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache]
request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache]
if batched_decode_attn_mask is None:
next_decode_attn_mask = None
else:
current_decode_mask_len = running_request.k_cache[0].shape[1] + 1
current_decode_attn_mask = batched_decode_attn_mask[
batch_index : batch_index + 1, :, :, -current_decode_mask_len:
]
next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False)
next_running.append(
T2SRunningRequest(
state=running_request.state,
y_sequence=new_history,
prefix_len=running_request.prefix_len,
decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False),
decode_attn_mask=next_decode_attn_mask,
k_cache=request_k_cache,
v_cache=request_v_cache,
step_idx=current_idx + 1,
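When the batch is split back per request, slicing the rightmost current_decode_mask_len slots discards only the alignment padding contributed by longer requests; the request's own mask (including its own prompt padding) survives intact, and one False slot is appended for the next token. Toy check with assumed lengths:

import torch
import torch.nn.functional as F

own_mask = torch.tensor([True, False, False, False]).view(1, 1, 1, 4)
batched_row = F.pad(own_mask, (1, 0), value=True)  # aligned to batch width 5
current_decode_mask_len = 4                        # old kv_len (3) + 1
recovered = batched_row[..., -current_decode_mask_len:]
assert torch.equal(recovered, own_mask)            # alignment padding dropped
next_mask = F.pad(recovered, (0, 1), value=False)  # slot for the next token
print(next_mask.shape)  # torch.Size([1, 1, 1, 5])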
@@ -593,6 +618,7 @@ def run_decode_step_for_running(
return next_running, finished_items
@torch.inference_mode()
def run_scheduler_continuous(
model: Any,
states: Sequence[T2SRequestState],

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import asyncio
import json
import subprocess
import threading
import time
import wave
from pathlib import Path
from typing import Any, Dict, List, Optional
import httpx
ROOT_DIR = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.")
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880")
parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit")
parser.add_argument("--concurrency", type=int, required=True)
parser.add_argument("--timeout-sec", type=float, default=120.0)
parser.add_argument("--server-pid", type=int, default=None)
parser.add_argument("--poll-interval-sec", type=float, default=0.1)
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--media-type", type=str, default="wav")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--sample-steps", type=int, default=32)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench")
return parser.parse_args()
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
wav_paths_all = sorted(args.wav_dir.glob("*.wav"))
wav_paths: List[Path] = []
for wav_path in wav_paths_all:
with wave.open(str(wav_path), "rb") as handle:
duration = handle.getnframes() / float(handle.getframerate())
if 3.0 <= duration <= 10.0:
wav_paths.append(wav_path)
if not wav_paths:
raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")
text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if not text_lines:
raise ValueError(f"没有找到有效文本行: {args.text_file}")
requests: List[Dict[str, Any]] = []
for index in range(args.concurrency):
wav_path = wav_paths[index % len(wav_paths)]
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"缺少参考文本: {lab_path}")
requests.append(
{
"request_id": f"bench_{args.concurrency:03d}_{index:03d}",
"text": text_lines[index % len(text_lines)],
"text_lang": args.text_lang,
"ref_audio_path": str(wav_path),
"prompt_lang": args.prompt_lang,
"prompt_text": lab_path.read_text(encoding="utf-8").strip(),
"top_k": int(args.top_k),
"top_p": float(args.top_p),
"temperature": float(args.temperature),
"repetition_penalty": float(args.repetition_penalty),
"sample_steps": int(args.sample_steps),
"media_type": args.media_type,
"timeout_sec": float(args.timeout_sec),
}
)
return requests
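# GpuMemoryPoller samples GPU memory on a daemon thread by shelling out to
# `nvidia-smi --query-compute-apps`; with server_pid set it tracks only that
# process, otherwise it sums every compute app on the machine.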
class GpuMemoryPoller:
def __init__(self, server_pid: Optional[int], interval_sec: float):
self.server_pid = server_pid
self.interval_sec = interval_sec
self._stop = threading.Event()
self.samples: List[Dict[str, Any]] = []
self.thread: Optional[threading.Thread] = None
def _query_memory_mb(self) -> Optional[int]:
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-compute-apps=pid,used_gpu_memory",
"--format=csv,noheader,nounits",
],
check=True,
capture_output=True,
text=True,
)
except Exception:
return None
total = 0
found = False
for line in result.stdout.splitlines():
line = line.strip()
if not line:
continue
parts = [item.strip() for item in line.split(",")]
if len(parts) != 2:
continue
try:
pid = int(parts[0])
used_mb = int(parts[1])
except ValueError:
continue
if self.server_pid is None or pid == self.server_pid:
total += used_mb
found = True
if self.server_pid is None:
return total
return total if found else 0
def _run(self) -> None:
while not self._stop.is_set():
used_mb = self._query_memory_mb()
self.samples.append({"ts": time.time(), "used_mb": used_mb})
self._stop.wait(self.interval_sec)
def start(self) -> None:
self.thread = threading.Thread(target=self._run, daemon=True)
self.thread.start()
def stop(self) -> None:
self._stop.set()
if self.thread is not None:
self.thread.join(timeout=2.0)
def summary(self) -> Dict[str, Any]:
valid = [item for item in self.samples if item["used_mb"] is not None]
peak = max(valid, key=lambda item: item["used_mb"]) if valid else None
first = valid[0] if valid else None
last = valid[-1] if valid else None
return {
"server_pid": self.server_pid,
"sample_count": int(len(self.samples)),
"start_used_mb": None if first is None else int(first["used_mb"]),
"peak_used_mb": None if peak is None else int(peak["used_mb"]),
"peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]),
"end_used_mb": None if last is None else int(last["used_mb"]),
"peak_ts": None if peak is None else float(peak["ts"]),
"samples": self.samples,
}
async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
started = time.perf_counter()
try:
response = await client.post(url, json=payload)
elapsed_ms = (time.perf_counter() - started) * 1000.0
item = {
"request_id": payload["request_id"],
"status_code": int(response.status_code),
"elapsed_ms": float(elapsed_ms),
"content_type": response.headers.get("content-type"),
"audio_bytes": int(len(response.content)),
"headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")},
}
if response.status_code != 200:
try:
item["error_body"] = response.json()
except Exception:
item["error_body"] = response.text
return item
except Exception as exc:
return {
"request_id": payload["request_id"],
"status_code": -1,
"elapsed_ms": float((time.perf_counter() - started) * 1000.0),
"exception": repr(exc),
}
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
payloads = load_requests(args)
url = args.base_url.rstrip("/") + args.endpoint
poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)
limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0)
started = time.perf_counter()
poller.start()
try:
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads])
finally:
poller.stop()
wall_ms = (time.perf_counter() - started) * 1000.0
ok_results = [item for item in results if item["status_code"] == 200]
failed_results = [item for item in results if item["status_code"] != 200]
request_total_ms = []
worker_total_ms = []
for item in ok_results:
headers = item.get("headers", {})
if "x-request-total-ms" in headers:
request_total_ms.append(float(headers["x-request-total-ms"]))
if "x-worker-total-ms" in headers:
worker_total_ms.append(float(headers["x-worker-total-ms"]))
return {
"concurrency": int(args.concurrency),
"server_pid": args.server_pid,
"request_count": int(len(payloads)),
"wall_ms": float(wall_ms),
"success_count": int(len(ok_results)),
"failure_count": int(len(failed_results)),
"request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None,
"request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None,
"worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None,
"worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None,
"gpu_memory": poller.summary(),
"results": results,
}
def main() -> None:
args = parse_args()
output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
output_dir.mkdir(parents=True, exist_ok=True)
summary = asyncio.run(run_benchmark(args))
summary_path = output_dir / "summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps({
"concurrency": summary["concurrency"],
"success_count": summary["success_count"],
"failure_count": summary["failure_count"],
"wall_ms": summary["wall_ms"],
"gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"],
"request_total_ms_avg": summary["request_total_ms_avg"],
"request_total_ms_max": summary["request_total_ms_max"],
"summary_path": str(summary_path),
}, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()
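The poller is also usable on its own around any local workload. A minimal sketch, assuming this script is importable as benchmark_api_v3 (hypothetical module name):

import time

from benchmark_api_v3 import GpuMemoryPoller  # hypothetical module name

poller = GpuMemoryPoller(server_pid=None, interval_sec=0.5)
poller.start()
time.sleep(5.0)  # stand-in for the workload under measurement
poller.stop()
print(poller.summary()["peak_used_mb"])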

View File

@@ -0,0 +1,887 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import gc
import contextlib
import json
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
sys.path.append(str(gpt_sovits_dir))
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
SchedulerRequestSpec,
T2SRequestState,
T2SRunningRequest,
_build_decode_batch_from_running,
build_prefill_batch,
prepare_request_state,
run_decode_step_for_running,
run_prefill_step,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
parser.add_argument("--request-manifest", type=Path, default=None)
parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
parser.add_argument("--auto-count", type=int, default=4)
parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--text", type=str, default=None)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--early-stop-num", type=int, default=-1)
parser.add_argument("--max-steps", type=int, default=1500)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--warmup", action="store_true", default=False)
parser.add_argument("--worker-rounds", type=int, default=1)
parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
parser.add_argument(
"--output-dir",
type=Path,
default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
)
return parser.parse_args()
def set_seed(seed: int, use_cuda: bool) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def _sync_device(device: Any) -> None:
try:
device_str = str(device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def bytes_to_mb(num_bytes: int) -> float:
return float(num_bytes) / (1024.0 * 1024.0)
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
if tensor is None:
return 0
return int(tensor.numel() * tensor.element_size())
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
return int(sum(tensor_nbytes(item) for item in items))
def model_nbytes(module: torch.nn.Module) -> int:
total = 0
for parameter in module.parameters():
total += tensor_nbytes(parameter)
for buffer in module.buffers():
total += tensor_nbytes(buffer)
return int(total)
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
modules = {
"t2s_model": tts.t2s_model,
"t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
"vits_model": tts.vits_model,
"bert_model": tts.bert_model,
"cnhuhbert_model": tts.cnhuhbert_model,
"vocoder": tts.vocoder,
"sv_model": tts.sv_model,
}
by_module = {}
total_bytes = 0
for name, module in modules.items():
module_bytes = model_nbytes(module) if module is not None else 0
by_module[name] = {
"bytes": int(module_bytes),
"mb": bytes_to_mb(module_bytes),
}
total_bytes += module_bytes
return {
"by_module": by_module,
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
}
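# snapshot_live_cuda_tensors walks gc-tracked objects, collecting CUDA tensors
# (including .data of parameter-like wrappers) and deduplicating them by
# untyped-storage pointer so views of one buffer are counted once; the largest
# storages and views are returned for leak triage.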
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
storages: Dict[int, Dict[str, Any]] = {}
tensor_views: List[Dict[str, Any]] = []
for obj in gc.get_objects():
try:
tensor = None
if torch.is_tensor(obj):
tensor = obj
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
tensor = obj.data
if tensor is None or not tensor.is_cuda:
continue
storage = tensor.untyped_storage()
storage_ptr = int(storage.data_ptr())
if storage_ptr not in storages:
storages[storage_ptr] = {
"storage_ptr": storage_ptr,
"storage_bytes": int(storage.nbytes()),
"dtype": str(tensor.dtype),
"shape": list(tensor.shape),
"device": str(tensor.device),
}
tensor_views.append(
{
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"bytes": tensor_nbytes(tensor),
"device": str(tensor.device),
}
)
except Exception:
continue
storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
return {
"unique_storage_count": int(len(storage_list)),
"unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
"unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
"top_storages": storage_list[:top_k],
"top_tensor_views": tensor_views[:top_k],
}
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
return [
SchedulerRequestSpec(
request_id="req_000",
ref_audio_path=args.ref_audio,
prompt_text=args.prompt_text,
prompt_lang=args.prompt_lang,
text=text,
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
]
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
if len(wav_paths) < args.auto_count:
raise ValueError(f"auto wav count不足目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if len(text_lines) < args.auto_count:
raise ValueError(f"auto text lines不足文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
specs: List[SchedulerRequestSpec] = []
for index, wav_path in enumerate(wav_paths):
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"找不到参考文本 {lab_path}")
specs.append(
SchedulerRequestSpec(
request_id=f"req_{index:03d}",
ref_audio_path=wav_path,
prompt_text=lab_path.read_text(encoding="utf-8").strip(),
prompt_lang="zh",
text=text_lines[index],
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
)
return specs
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
if args.request_manifest is not None:
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
specs: List[SchedulerRequestSpec] = []
for index, item in enumerate(raw_requests):
text = item.get("text")
text_file = item.get("text_file")
if text is None and text_file is None:
raise ValueError(f"request[{index}] must provide text or text_file")
if text is None:
text = Path(text_file).read_text(encoding="utf-8").strip()
specs.append(
SchedulerRequestSpec(
request_id=item.get("request_id", f"req_{index:03d}"),
ref_audio_path=Path(item["ref_audio_path"]),
prompt_text=item["prompt_text"],
prompt_lang=item.get("prompt_lang", "zh"),
text=text,
text_lang=item.get("text_lang", "zh"),
top_k=int(item.get("top_k", args.top_k)),
top_p=float(item.get("top_p", args.top_p)),
temperature=float(item.get("temperature", args.temperature)),
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
ready_step=int(item.get("ready_step", 0)),
)
)
return specs
if args.scenario == "single":
return build_single_spec(args)
return build_auto_specs(args)
def load_pipeline(config_path: Path) -> TTS:
tts_config = TTS_Config(str(config_path))
print(tts_config)
return TTS(tts_config)
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
if not (str(device).startswith("cuda") and torch.cuda.is_available()):
return {
"allocated_mb": 0.0,
"reserved_mb": 0.0,
"max_allocated_mb": 0.0,
"max_reserved_mb": 0.0,
}
_sync_device(device)
return {
"allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
"reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
"max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
"max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
}
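# stage_run measures one stage in isolation: collect garbage, synchronize,
# reset the CUDA peak counter, snapshot memory before and after fn(), and
# report deltas plus the stage-local transient peak over the pre-stage
# baseline (stage_peak_over_before_mb) alongside wall time.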
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.reset_peak_memory_stats(device)
before = cuda_mem_snapshot(device)
started = time.perf_counter()
result = fn()
_sync_device(device)
elapsed_ms = (time.perf_counter() - started) * 1000.0
after = cuda_mem_snapshot(device)
after["elapsed_ms"] = float(elapsed_ms)
after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
return result, after
class GlobalPeakRecorder:
def __init__(self, device: Any):
self.device = device
self.checkpoints: List[Dict[str, Any]] = []
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
def record(self, label: str, **extra: Any) -> None:
snapshot = cuda_mem_snapshot(self.device)
snapshot["label"] = label
snapshot.update(extra)
self.checkpoints.append(snapshot)
def summary(self) -> Dict[str, Any]:
peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
return {
"peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
"peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
"peak_label": None if peak is None else peak["label"],
"checkpoints": self.checkpoints,
}
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
per_request = []
total = {
"phones_bytes": 0,
"prompt_phones_bytes": 0,
"all_phones_bytes": 0,
"all_bert_features_bytes": 0,
"prompt_semantic_bytes": 0,
"refer_spec_bytes": 0,
"raw_audio_bytes": 0,
"audio_16k_bytes": 0,
}
for state in states:
spec_audio, audio_16k = state.refer_spec
item = {
"request_id": state.request_id,
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"phones_len": int(state.phones.shape[0]),
"all_phones_len": int(state.all_phones.shape[0]),
"bert_frames": int(state.all_bert_features.shape[-1]),
"phones_bytes": tensor_nbytes(state.phones),
"prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
"all_phones_bytes": tensor_nbytes(state.all_phones),
"all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
"prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
"refer_spec_bytes": tensor_nbytes(spec_audio),
"audio_16k_bytes": tensor_nbytes(audio_16k),
"raw_audio_bytes": tensor_nbytes(state.raw_audio),
}
for key in total:
total[key] += int(item[key])
per_request.append(item)
total["total_bytes"] = int(sum(total.values()))
total["total_mb"] = bytes_to_mb(total["total_bytes"])
return {"per_request": per_request, "total": total}
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
fields = {
"x_bytes": tensor_nbytes(active_batch.x),
"x_lens_bytes": tensor_nbytes(active_batch.x_lens),
"prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
"xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
"key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
"prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
"y_sequence_bytes": y_sequence_bytes,
}
fields["total_bytes"] = int(sum(fields.values()))
fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
fields["batch_size"] = int(len(active_batch.states))
fields["max_x_len"] = int(active_batch.x.shape[1])
fields["src_len"] = int(active_batch.xy_pos.shape[1])
fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
return fields
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
per_request = []
total_private_k_bytes = 0
total_private_v_bytes = 0
total_decode_mask_bytes = 0
total_y_sequence_bytes = 0
for item in running_requests:
k_bytes = tensor_list_nbytes(item.k_cache)
v_bytes = tensor_list_nbytes(item.v_cache)
mask_bytes = tensor_nbytes(item.decode_attn_mask)
y_bytes = tensor_nbytes(item.y_sequence)
total_private_k_bytes += k_bytes
total_private_v_bytes += v_bytes
total_decode_mask_bytes += mask_bytes
total_y_sequence_bytes += y_bytes
per_request.append(
{
"request_id": item.state.request_id,
"step_idx": int(item.step_idx),
"prefix_len": int(item.prefix_len),
"history_len": int(item.y_sequence.shape[0]),
"kv_len": int(item.k_cache[0].shape[1]),
"k_cache_bytes": k_bytes,
"v_cache_bytes": v_bytes,
"decode_mask_bytes": mask_bytes,
"y_sequence_bytes": y_bytes,
}
)
total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
return {
"per_request": per_request,
"totals": {
"private_k_cache_bytes": int(total_private_k_bytes),
"private_v_cache_bytes": int(total_private_v_bytes),
"private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
"decode_mask_bytes": int(total_decode_mask_bytes),
"y_sequence_bytes": int(total_y_sequence_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
},
}
def summarise_decode_batch(
xy_pos: torch.Tensor,
batched_k_cache: Sequence[torch.Tensor],
batched_v_cache: Sequence[torch.Tensor],
batched_decode_attn_mask: Optional[torch.Tensor],
running_requests: Sequence[T2SRunningRequest],
) -> Dict[str, Any]:
private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
batched_k_bytes = tensor_list_nbytes(batched_k_cache)
batched_v_bytes = tensor_list_nbytes(batched_v_cache)
batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
xy_pos_bytes = tensor_nbytes(xy_pos)
total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
return {
"batch_size": int(len(running_requests)),
"xy_pos_bytes": int(xy_pos_bytes),
"batched_k_cache_bytes": int(batched_k_bytes),
"batched_v_cache_bytes": int(batched_v_bytes),
"batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
"batched_decode_mask_bytes": int(batched_mask_bytes),
"private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
"kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_pos_shape": list(xy_pos.shape),
"batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
"layer_k_cache_shape": list(batched_k_cache[0].shape),
}
def summarise_decode_outputs(
xy_dec: torch.Tensor,
next_k_cache: Sequence[torch.Tensor],
next_v_cache: Sequence[torch.Tensor],
) -> Dict[str, Any]:
xy_dec_bytes = tensor_nbytes(xy_dec)
next_k_bytes = tensor_list_nbytes(next_k_cache)
next_v_bytes = tensor_list_nbytes(next_v_cache)
total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
return {
"xy_dec_bytes": int(xy_dec_bytes),
"next_k_cache_bytes": int(next_k_bytes),
"next_v_cache_bytes": int(next_v_bytes),
"next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_dec_shape": list(xy_dec.shape),
"layer_next_k_cache_shape": list(next_k_cache[0].shape),
}
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
ranking = [
("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
]
ranking.sort(key=lambda item: item[1], reverse=True)
return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
phones = state.phones.unsqueeze(0).to(tts.configs.device)
audio_fragment = tts.synthesize_audio_request_local(
semantic_tokens=semantic_tokens,
phones=phones,
prompt_semantic=state.prompt_semantic,
prompt_phones=state.prompt_phones,
refer_spec=state.refer_spec,
raw_audio=state.raw_audio,
raw_sr=state.raw_sr,
speed=1.0,
sample_steps=32,
)
output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
return tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=1.0,
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
def simulate_worker_end_to_end(
tts: TTS,
specs: Sequence[SchedulerRequestSpec],
max_steps: int,
rounds: int,
grad_mode: str = "default",
) -> Dict[str, Any]:
device = tts.configs.device
recorder = GlobalPeakRecorder(device)
recorder.record("after_model_load")
state_map: Dict[str, T2SRequestState] = {}
per_round: List[Dict[str, Any]] = []
for round_index in range(rounds):
grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
with grad_context():
states = [prepare_request_state(tts, spec) for spec in specs]
state_map = {state.request_id: state for state in states}
recorder.record(
"after_prepare_states",
round_index=int(round_index),
request_count=int(len(states)),
grad_mode=grad_mode,
)
pending = list(states)
running_requests: List[T2SRunningRequest] = []
round_events: List[Dict[str, Any]] = []
current_tick = 0
while pending or running_requests:
admitted = pending
pending = []
if admitted:
recorder.record(
"before_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_count=int(len(admitted)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
recorder.record(
"after_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_running_count=int(len(admitted_running)),
admitted_finished_count=int(len(admitted_finished)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "prefill",
"admitted_count": int(len(admitted)),
"admitted_running_count": int(len(admitted_running)),
"admitted_finished_count": int(len(admitted_finished)),
}
)
for item in admitted_finished:
recorder.record(
"before_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
running_requests.extend(admitted_running)
recorder.record(
"after_extend_running",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
if running_requests:
recorder.record(
"before_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
running_requests, step_finished = run_decode_step_for_running(
tts.t2s_model.model,
running_requests,
max_steps=max_steps,
)
recorder.record(
"after_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_count=int(len(step_finished)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "decode",
"running_count_after_decode": int(len(running_requests)),
"finished_count": int(len(step_finished)),
}
)
for item in step_finished:
recorder.record(
"before_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
current_tick += 1
recorder.record(
"after_round_complete",
round_index=int(round_index),
running_count=0,
grad_mode=grad_mode,
)
per_round.append(
{
"round_index": int(round_index),
"events": round_events,
}
)
return {
"grad_mode": grad_mode,
"rounds": per_round,
"timeline": recorder.summary(),
}
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
tts = load_pipeline(args.config)
model = tts.t2s_model.model
device = tts.configs.device
use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
set_seed(args.seed, use_cuda)
specs = load_request_specs(args)
if args.early_stop_num == -1:
for spec in specs:
spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)
if args.warmup and specs:
warmup_spec = specs[:1]
_ = [prepare_request_state(tts, spec) for spec in warmup_spec]
gc.collect()
if use_cuda:
torch.cuda.empty_cache()
_sync_device(device)
states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
request_state_summary = summarise_state_tensors(states)
active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)
prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
running_requests, finished_items = prefill_result
running_requests_summary = summarise_running_requests(running_requests)
finished_after_prefill_summary = [
{
"request_id": item.request_id,
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
}
for item in finished_items
]
if not running_requests:
raise RuntimeError(f"prefill 后没有 running requests全部在首步结束: {[item.request_id for item in finished_items]}")
decode_batch_result, decode_batch_mem = stage_run(
device,
lambda: _build_decode_batch_from_running(model, running_requests),
)
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
decode_batch_tensor_summary = summarise_decode_batch(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
running_requests,
)
decode_out_result, decode_step_mem = stage_run(
device,
lambda: model.t2s_transformer.decode_next_token(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
),
)
xy_dec, next_k_cache, next_v_cache = decode_out_result
decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
del active_batch
del running_requests
del finished_items
del xy_pos
del batched_k_cache
del batched_v_cache
del batched_decode_attn_mask
del xy_dec
del next_k_cache
del next_v_cache
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
end_to_end_worker = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode=args.worker_grad_mode,
)
live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
worker_inference_mode = None
if args.compare_worker_grad_modes:
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
worker_inference_mode = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode="inference_mode",
)
summary = {
"meta": {
"scenario": args.scenario if args.request_manifest is None else "manifest",
"seed": int(args.seed),
"device": str(device),
"dtype": str(next(model.parameters()).dtype),
"request_count": int(len(specs)),
"num_layers": int(model.num_layers),
"num_heads": int(model.num_head),
"model_dim": int(model.model_dim),
"model_weights_mb": bytes_to_mb(model_nbytes(model)),
},
"loaded_module_weights": build_module_weight_summary(tts),
"requests": [
{
"request_id": spec.request_id,
"ref_audio_path": str(spec.ref_audio_path),
"prompt_text": spec.prompt_text,
"text": spec.text,
}
for spec in specs
],
"prepare_stage": {
"memory": prepare_mem,
"request_state": request_state_summary,
},
"prefill_batch": {
"memory": prefill_batch_mem,
"tensor_bytes": prefill_batch_tensor_summary,
},
"prefill_step": {
"memory": prefill_step_mem,
"running_requests": running_requests_summary,
"finished_after_prefill": finished_after_prefill_summary,
},
"decode_batch": {
"memory": decode_batch_mem,
"tensor_bytes": decode_batch_tensor_summary,
},
"decode_outputs": {
"memory": decode_step_mem,
"tensor_bytes": decode_output_tensor_summary,
},
"end_to_end_worker": end_to_end_worker,
"live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
"end_to_end_worker_inference_mode": worker_inference_mode,
}
summary["top_rankings"] = top_rankings(summary)
summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
print("[top_rankings]")
for item in summary["top_rankings"]:
print(f"- {item['name']}: {item['mb']:.3f} MB")
print("[worker_peak]")
print(
json.dumps(
{
"peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
"peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
if worker_inference_mode is not None:
print("[worker_peak_inference_mode]")
print(
json.dumps(
{
"peak_label": worker_inference_mode["timeline"]["peak_label"],
"peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
print(f"[summary] {summary_path}")
if __name__ == "__main__":
main()
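The JSON summary is meant for post-processing; a small sketch that reloads it and reprints the rankings, assuming the default --output-dir:

import json
from pathlib import Path

summary_path = Path("TEMP/t2s_memory_breakdown/run1/t2s_memory_breakdown_summary.json")
summary = json.loads(summary_path.read_text(encoding="utf-8"))
for item in summary["top_rankings"]:
    print(f"{item['name']}: {item['mb']:.1f} MB")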