GPT-SoVITS/tools/t2s_memory_breakdown.py

888 lines
35 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import gc
import contextlib
import json
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
sys.path.append(str(gpt_sovits_dir))
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
SchedulerRequestSpec,
T2SRequestState,
T2SRunningRequest,
_build_decode_batch_from_running,
build_prefill_batch,
prepare_request_state,
run_decode_step_for_running,
run_prefill_step,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
parser.add_argument("--request-manifest", type=Path, default=None)
parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
parser.add_argument("--auto-count", type=int, default=4)
parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--text", type=str, default=None)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--early-stop-num", type=int, default=-1)
parser.add_argument("--max-steps", type=int, default=1500)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--warmup", action="store_true", default=False)
parser.add_argument("--worker-rounds", type=int, default=1)
parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
parser.add_argument(
"--output-dir",
type=Path,
default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
)
return parser.parse_args()
def set_seed(seed: int, use_cuda: bool) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def _sync_device(device: Any) -> None:
try:
device_str = str(device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def bytes_to_mb(num_bytes: int) -> float:
return float(num_bytes) / (1024.0 * 1024.0)
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
if tensor is None:
return 0
return int(tensor.numel() * tensor.element_size())
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
return int(sum(tensor_nbytes(item) for item in items))
def model_nbytes(module: torch.nn.Module) -> int:
total = 0
for parameter in module.parameters():
total += tensor_nbytes(parameter)
for buffer in module.buffers():
total += tensor_nbytes(buffer)
return int(total)
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
modules = {
"t2s_model": tts.t2s_model,
"t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
"vits_model": tts.vits_model,
"bert_model": tts.bert_model,
"cnhuhbert_model": tts.cnhuhbert_model,
"vocoder": tts.vocoder,
"sv_model": tts.sv_model,
}
by_module = {}
total_bytes = 0
for name, module in modules.items():
module_bytes = model_nbytes(module) if module is not None else 0
by_module[name] = {
"bytes": int(module_bytes),
"mb": bytes_to_mb(module_bytes),
}
total_bytes += module_bytes
return {
"by_module": by_module,
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
}
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
storages: Dict[int, Dict[str, Any]] = {}
tensor_views: List[Dict[str, Any]] = []
for obj in gc.get_objects():
try:
tensor = None
if torch.is_tensor(obj):
tensor = obj
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
tensor = obj.data
if tensor is None or not tensor.is_cuda:
continue
storage = tensor.untyped_storage()
storage_ptr = int(storage.data_ptr())
if storage_ptr not in storages:
storages[storage_ptr] = {
"storage_ptr": storage_ptr,
"storage_bytes": int(storage.nbytes()),
"dtype": str(tensor.dtype),
"shape": list(tensor.shape),
"device": str(tensor.device),
}
tensor_views.append(
{
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"bytes": tensor_nbytes(tensor),
"device": str(tensor.device),
}
)
except Exception:
continue
storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
return {
"unique_storage_count": int(len(storage_list)),
"unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
"unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
"top_storages": storage_list[:top_k],
"top_tensor_views": tensor_views[:top_k],
}
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
return [
SchedulerRequestSpec(
request_id="req_000",
ref_audio_path=args.ref_audio,
prompt_text=args.prompt_text,
prompt_lang=args.prompt_lang,
text=text,
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
]
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
if len(wav_paths) < args.auto_count:
raise ValueError(f"auto wav count不足目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if len(text_lines) < args.auto_count:
raise ValueError(f"auto text lines不足文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
specs: List[SchedulerRequestSpec] = []
for index, wav_path in enumerate(wav_paths):
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"找不到参考文本 {lab_path}")
specs.append(
SchedulerRequestSpec(
request_id=f"req_{index:03d}",
ref_audio_path=wav_path,
prompt_text=lab_path.read_text(encoding="utf-8").strip(),
prompt_lang="zh",
text=text_lines[index],
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
)
return specs
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
if args.request_manifest is not None:
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
specs: List[SchedulerRequestSpec] = []
for index, item in enumerate(raw_requests):
text = item.get("text")
text_file = item.get("text_file")
if text is None and text_file is None:
raise ValueError(f"request[{index}] must provide text or text_file")
if text is None:
text = Path(text_file).read_text(encoding="utf-8").strip()
specs.append(
SchedulerRequestSpec(
request_id=item.get("request_id", f"req_{index:03d}"),
ref_audio_path=Path(item["ref_audio_path"]),
prompt_text=item["prompt_text"],
prompt_lang=item.get("prompt_lang", "zh"),
text=text,
text_lang=item.get("text_lang", "zh"),
top_k=int(item.get("top_k", args.top_k)),
top_p=float(item.get("top_p", args.top_p)),
temperature=float(item.get("temperature", args.temperature)),
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
ready_step=int(item.get("ready_step", 0)),
)
)
return specs
if args.scenario == "single":
return build_single_spec(args)
return build_auto_specs(args)
def load_pipeline(config_path: Path) -> TTS:
tts_config = TTS_Config(str(config_path))
print(tts_config)
return TTS(tts_config)
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
if not (str(device).startswith("cuda") and torch.cuda.is_available()):
return {
"allocated_mb": 0.0,
"reserved_mb": 0.0,
"max_allocated_mb": 0.0,
"max_reserved_mb": 0.0,
}
_sync_device(device)
return {
"allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
"reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
"max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
"max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
}
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.reset_peak_memory_stats(device)
before = cuda_mem_snapshot(device)
started = time.perf_counter()
result = fn()
_sync_device(device)
elapsed_ms = (time.perf_counter() - started) * 1000.0
after = cuda_mem_snapshot(device)
after["elapsed_ms"] = float(elapsed_ms)
after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
return result, after
class GlobalPeakRecorder:
def __init__(self, device: Any):
self.device = device
self.checkpoints: List[Dict[str, Any]] = []
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
def record(self, label: str, **extra: Any) -> None:
snapshot = cuda_mem_snapshot(self.device)
snapshot["label"] = label
snapshot.update(extra)
self.checkpoints.append(snapshot)
def summary(self) -> Dict[str, Any]:
peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
return {
"peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
"peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
"peak_label": None if peak is None else peak["label"],
"checkpoints": self.checkpoints,
}
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
per_request = []
total = {
"phones_bytes": 0,
"prompt_phones_bytes": 0,
"all_phones_bytes": 0,
"all_bert_features_bytes": 0,
"prompt_semantic_bytes": 0,
"refer_spec_bytes": 0,
"raw_audio_bytes": 0,
"audio_16k_bytes": 0,
}
for state in states:
spec_audio, audio_16k = state.refer_spec
item = {
"request_id": state.request_id,
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"phones_len": int(state.phones.shape[0]),
"all_phones_len": int(state.all_phones.shape[0]),
"bert_frames": int(state.all_bert_features.shape[-1]),
"phones_bytes": tensor_nbytes(state.phones),
"prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
"all_phones_bytes": tensor_nbytes(state.all_phones),
"all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
"prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
"refer_spec_bytes": tensor_nbytes(spec_audio),
"audio_16k_bytes": tensor_nbytes(audio_16k),
"raw_audio_bytes": tensor_nbytes(state.raw_audio),
}
for key in total:
total[key] += int(item[key])
per_request.append(item)
total["total_bytes"] = int(sum(total.values()))
total["total_mb"] = bytes_to_mb(total["total_bytes"])
return {"per_request": per_request, "total": total}
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
fields = {
"x_bytes": tensor_nbytes(active_batch.x),
"x_lens_bytes": tensor_nbytes(active_batch.x_lens),
"prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
"xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
"key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
"prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
"y_sequence_bytes": y_sequence_bytes,
}
fields["total_bytes"] = int(sum(fields.values()))
fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
fields["batch_size"] = int(len(active_batch.states))
fields["max_x_len"] = int(active_batch.x.shape[1])
fields["src_len"] = int(active_batch.xy_pos.shape[1])
fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
return fields
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
per_request = []
total_private_k_bytes = 0
total_private_v_bytes = 0
total_decode_mask_bytes = 0
total_y_sequence_bytes = 0
for item in running_requests:
k_bytes = tensor_list_nbytes(item.k_cache)
v_bytes = tensor_list_nbytes(item.v_cache)
mask_bytes = tensor_nbytes(item.decode_attn_mask)
y_bytes = tensor_nbytes(item.y_sequence)
total_private_k_bytes += k_bytes
total_private_v_bytes += v_bytes
total_decode_mask_bytes += mask_bytes
total_y_sequence_bytes += y_bytes
per_request.append(
{
"request_id": item.state.request_id,
"step_idx": int(item.step_idx),
"prefix_len": int(item.prefix_len),
"history_len": int(item.y_sequence.shape[0]),
"kv_len": int(item.k_cache[0].shape[1]),
"k_cache_bytes": k_bytes,
"v_cache_bytes": v_bytes,
"decode_mask_bytes": mask_bytes,
"y_sequence_bytes": y_bytes,
}
)
total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
return {
"per_request": per_request,
"totals": {
"private_k_cache_bytes": int(total_private_k_bytes),
"private_v_cache_bytes": int(total_private_v_bytes),
"private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
"decode_mask_bytes": int(total_decode_mask_bytes),
"y_sequence_bytes": int(total_y_sequence_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
},
}
def summarise_decode_batch(
xy_pos: torch.Tensor,
batched_k_cache: Sequence[torch.Tensor],
batched_v_cache: Sequence[torch.Tensor],
batched_decode_attn_mask: Optional[torch.Tensor],
running_requests: Sequence[T2SRunningRequest],
) -> Dict[str, Any]:
private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
batched_k_bytes = tensor_list_nbytes(batched_k_cache)
batched_v_bytes = tensor_list_nbytes(batched_v_cache)
batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
xy_pos_bytes = tensor_nbytes(xy_pos)
total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
return {
"batch_size": int(len(running_requests)),
"xy_pos_bytes": int(xy_pos_bytes),
"batched_k_cache_bytes": int(batched_k_bytes),
"batched_v_cache_bytes": int(batched_v_bytes),
"batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
"batched_decode_mask_bytes": int(batched_mask_bytes),
"private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
"kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_pos_shape": list(xy_pos.shape),
"batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
"layer_k_cache_shape": list(batched_k_cache[0].shape),
}
def summarise_decode_outputs(
xy_dec: torch.Tensor,
next_k_cache: Sequence[torch.Tensor],
next_v_cache: Sequence[torch.Tensor],
) -> Dict[str, Any]:
xy_dec_bytes = tensor_nbytes(xy_dec)
next_k_bytes = tensor_list_nbytes(next_k_cache)
next_v_bytes = tensor_list_nbytes(next_v_cache)
total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
return {
"xy_dec_bytes": int(xy_dec_bytes),
"next_k_cache_bytes": int(next_k_bytes),
"next_v_cache_bytes": int(next_v_bytes),
"next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_dec_shape": list(xy_dec.shape),
"layer_next_k_cache_shape": list(next_k_cache[0].shape),
}
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
ranking = [
("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
]
ranking.sort(key=lambda item: item[1], reverse=True)
return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
phones = state.phones.unsqueeze(0).to(tts.configs.device)
audio_fragment = tts.synthesize_audio_request_local(
semantic_tokens=semantic_tokens,
phones=phones,
prompt_semantic=state.prompt_semantic,
prompt_phones=state.prompt_phones,
refer_spec=state.refer_spec,
raw_audio=state.raw_audio,
raw_sr=state.raw_sr,
speed=1.0,
sample_steps=32,
)
output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
return tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=1.0,
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
def simulate_worker_end_to_end(
tts: TTS,
specs: Sequence[SchedulerRequestSpec],
max_steps: int,
rounds: int,
grad_mode: str = "default",
) -> Dict[str, Any]:
device = tts.configs.device
recorder = GlobalPeakRecorder(device)
recorder.record("after_model_load")
state_map: Dict[str, T2SRequestState] = {}
per_round: List[Dict[str, Any]] = []
for round_index in range(rounds):
grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
with grad_context():
states = [prepare_request_state(tts, spec) for spec in specs]
state_map = {state.request_id: state for state in states}
recorder.record(
"after_prepare_states",
round_index=int(round_index),
request_count=int(len(states)),
grad_mode=grad_mode,
)
pending = list(states)
running_requests: List[T2SRunningRequest] = []
round_events: List[Dict[str, Any]] = []
current_tick = 0
while pending or running_requests:
admitted = pending
pending = []
if admitted:
recorder.record(
"before_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_count=int(len(admitted)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
recorder.record(
"after_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_running_count=int(len(admitted_running)),
admitted_finished_count=int(len(admitted_finished)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "prefill",
"admitted_count": int(len(admitted)),
"admitted_running_count": int(len(admitted_running)),
"admitted_finished_count": int(len(admitted_finished)),
}
)
for item in admitted_finished:
recorder.record(
"before_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
running_requests.extend(admitted_running)
recorder.record(
"after_extend_running",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
if running_requests:
recorder.record(
"before_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
running_requests, step_finished = run_decode_step_for_running(
tts.t2s_model.model,
running_requests,
max_steps=max_steps,
)
recorder.record(
"after_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_count=int(len(step_finished)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "decode",
"running_count_after_decode": int(len(running_requests)),
"finished_count": int(len(step_finished)),
}
)
for item in step_finished:
recorder.record(
"before_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
current_tick += 1
recorder.record(
"after_round_complete",
round_index=int(round_index),
running_count=0,
grad_mode=grad_mode,
)
per_round.append(
{
"round_index": int(round_index),
"events": round_events,
}
)
return {
"grad_mode": grad_mode,
"rounds": per_round,
"timeline": recorder.summary(),
}
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
tts = load_pipeline(args.config)
model = tts.t2s_model.model
device = tts.configs.device
use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
set_seed(args.seed, use_cuda)
specs = load_request_specs(args)
if args.early_stop_num == -1:
for spec in specs:
spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)
if args.warmup and specs:
warmup_spec = specs[:1]
_ = [prepare_request_state(tts, spec) for spec in warmup_spec]
gc.collect()
if use_cuda:
torch.cuda.empty_cache()
_sync_device(device)
states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
request_state_summary = summarise_state_tensors(states)
active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)
prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
running_requests, finished_items = prefill_result
running_requests_summary = summarise_running_requests(running_requests)
finished_after_prefill_summary = [
{
"request_id": item.request_id,
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
}
for item in finished_items
]
if not running_requests:
raise RuntimeError(f"prefill 后没有 running requests全部在首步结束: {[item.request_id for item in finished_items]}")
decode_batch_result, decode_batch_mem = stage_run(
device,
lambda: _build_decode_batch_from_running(model, running_requests),
)
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
decode_batch_tensor_summary = summarise_decode_batch(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
running_requests,
)
decode_out_result, decode_step_mem = stage_run(
device,
lambda: model.t2s_transformer.decode_next_token(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
),
)
xy_dec, next_k_cache, next_v_cache = decode_out_result
decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
del active_batch
del running_requests
del finished_items
del xy_pos
del batched_k_cache
del batched_v_cache
del batched_decode_attn_mask
del xy_dec
del next_k_cache
del next_v_cache
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
end_to_end_worker = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode=args.worker_grad_mode,
)
live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
worker_inference_mode = None
if args.compare_worker_grad_modes:
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
worker_inference_mode = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode="inference_mode",
)
summary = {
"meta": {
"scenario": args.scenario if args.request_manifest is None else "manifest",
"seed": int(args.seed),
"device": str(device),
"dtype": str(next(model.parameters()).dtype),
"request_count": int(len(specs)),
"num_layers": int(model.num_layers),
"num_heads": int(model.num_head),
"model_dim": int(model.model_dim),
"model_weights_mb": bytes_to_mb(model_nbytes(model)),
},
"loaded_module_weights": build_module_weight_summary(tts),
"requests": [
{
"request_id": spec.request_id,
"ref_audio_path": str(spec.ref_audio_path),
"prompt_text": spec.prompt_text,
"text": spec.text,
}
for spec in specs
],
"prepare_stage": {
"memory": prepare_mem,
"request_state": request_state_summary,
},
"prefill_batch": {
"memory": prefill_batch_mem,
"tensor_bytes": prefill_batch_tensor_summary,
},
"prefill_step": {
"memory": prefill_step_mem,
"running_requests": running_requests_summary,
"finished_after_prefill": finished_after_prefill_summary,
},
"decode_batch": {
"memory": decode_batch_mem,
"tensor_bytes": decode_batch_tensor_summary,
},
"decode_outputs": {
"memory": decode_step_mem,
"tensor_bytes": decode_output_tensor_summary,
},
"end_to_end_worker": end_to_end_worker,
"live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
"end_to_end_worker_inference_mode": worker_inference_mode,
}
summary["top_rankings"] = top_rankings(summary)
summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
print("[top_rankings]")
for item in summary["top_rankings"]:
print(f"- {item['name']}: {item['mb']:.3f} MB")
print("[worker_peak]")
print(
json.dumps(
{
"peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
"peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
if worker_inference_mode is not None:
print("[worker_peak_inference_mode]")
print(
json.dumps(
{
"peak_label": worker_inference_mode["timeline"]["peak_label"],
"peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
print(f"[summary] {summary_path}")
if __name__ == "__main__":
main()