#!/usr/bin/env python3
# -*- coding: utf-8 -*-
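"""Break down T2S CUDA memory usage by stage and by tensor group.

Loads a GPT-SoVITS TTS pipeline, measures allocator statistics around the
prepare/prefill/decode stages of the T2S scheduler, simulates an end-to-end
worker loop with labelled peak-memory checkpoints, and writes a JSON summary
to --output-dir.

Example invocation (the script path is illustrative; adjust it to where this
file lives in your checkout):

    python t2s_memory_breakdown.py --scenario auto4 --output-dir TEMP/t2s_memory_breakdown/run1
"""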

from __future__ import annotations

import argparse
import contextlib
import gc
import json
import random
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch

ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
    sys.path.append(str(gpt_sovits_dir))

from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config  # noqa: E402
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (  # noqa: E402
    SchedulerRequestSpec,
    T2SRequestState,
    T2SRunningRequest,
    _build_decode_batch_from_running,
    build_prefill_batch,
    prepare_request_state,
    run_decode_step_for_running,
    run_prefill_step,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
    parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
    parser.add_argument("--request-manifest", type=Path, default=None)
    parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
    parser.add_argument("--auto-count", type=int, default=4)
    parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
    parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
    parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
    parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
    parser.add_argument("--prompt-lang", type=str, default="zh")
    parser.add_argument("--text", type=str, default=None)
    parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
    parser.add_argument("--text-lang", type=str, default="zh")
    parser.add_argument("--top-k", type=int, default=15)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.35)
    parser.add_argument("--early-stop-num", type=int, default=-1)
    parser.add_argument("--max-steps", type=int, default=1500)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--warmup", action="store_true", default=False)
    parser.add_argument("--worker-rounds", type=int, default=1)
    parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
    parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
    )
    return parser.parse_args()


def set_seed(seed: int, use_cuda: bool) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def _sync_device(device: Any) -> None:
    try:
        device_str = str(device)
        if device_str.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize(device)
        elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
            torch.mps.synchronize()
    except Exception:
        pass


def bytes_to_mb(num_bytes: int) -> float:
    return float(num_bytes) / (1024.0 * 1024.0)


def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
    if tensor is None:
        return 0
    return int(tensor.numel() * tensor.element_size())


def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
    return int(sum(tensor_nbytes(item) for item in items))


def model_nbytes(module: torch.nn.Module) -> int:
    total = 0
    for parameter in module.parameters():
        total += tensor_nbytes(parameter)
    for buffer in module.buffers():
        total += tensor_nbytes(buffer)
    return int(total)


def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
    modules = {
        "t2s_model": tts.t2s_model,
        "t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
        "vits_model": tts.vits_model,
        "bert_model": tts.bert_model,
        "cnhuhbert_model": tts.cnhuhbert_model,
        "vocoder": tts.vocoder,
        "sv_model": tts.sv_model,
    }
    by_module = {}
    total_bytes = 0
    for name, module in modules.items():
        module_bytes = model_nbytes(module) if module is not None else 0
        by_module[name] = {
            "bytes": int(module_bytes),
            "mb": bytes_to_mb(module_bytes),
        }
        total_bytes += module_bytes
    return {
        "by_module": by_module,
        "total_bytes": int(total_bytes),
        "total_mb": bytes_to_mb(total_bytes),
    }


def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
    # Walk the GC heap and group live CUDA tensors by their underlying
    # untyped storage, so views over the same allocation are counted once.
    storages: Dict[int, Dict[str, Any]] = {}
    tensor_views: List[Dict[str, Any]] = []
    for obj in gc.get_objects():
        try:
            tensor = None
            if torch.is_tensor(obj):
                tensor = obj
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                tensor = obj.data
            if tensor is None or not tensor.is_cuda:
                continue
            storage = tensor.untyped_storage()
            storage_ptr = int(storage.data_ptr())
            if storage_ptr not in storages:
                storages[storage_ptr] = {
                    "storage_ptr": storage_ptr,
                    "storage_bytes": int(storage.nbytes()),
                    "dtype": str(tensor.dtype),
                    "shape": list(tensor.shape),
                    "device": str(tensor.device),
                }
            tensor_views.append(
                {
                    "shape": list(tensor.shape),
                    "dtype": str(tensor.dtype),
                    "bytes": tensor_nbytes(tensor),
                    "device": str(tensor.device),
                }
            )
        except Exception:
            # Some GC-tracked objects raise on attribute access; skip them.
            continue
    storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
    tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
    return {
        "unique_storage_count": int(len(storage_list)),
        "unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
        "unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
        "top_storages": storage_list[:top_k],
        "top_tensor_views": tensor_views[:top_k],
    }


def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
    return [
        SchedulerRequestSpec(
            request_id="req_000",
            ref_audio_path=args.ref_audio,
            prompt_text=args.prompt_text,
            prompt_lang=args.prompt_lang,
            text=text,
            text_lang=args.text_lang,
            top_k=args.top_k,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            early_stop_num=args.early_stop_num,
            ready_step=0,
        )
    ]


def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
    if len(wav_paths) < args.auto_count:
        raise ValueError(f"Not enough wav files: directory {args.auto_wav_dir} only contains {len(wav_paths)} wav files")
    text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
    if len(text_lines) < args.auto_count:
        raise ValueError(f"Not enough text lines: file {args.auto_text_file} only contains {len(text_lines)} non-empty lines")
    specs: List[SchedulerRequestSpec] = []
    for index, wav_path in enumerate(wav_paths):
        lab_path = wav_path.with_suffix(".lab")
        if not lab_path.exists():
            raise FileNotFoundError(f"Reference transcript not found: {lab_path}")
        specs.append(
            SchedulerRequestSpec(
                request_id=f"req_{index:03d}",
                ref_audio_path=wav_path,
                prompt_text=lab_path.read_text(encoding="utf-8").strip(),
                prompt_lang="zh",
                text=text_lines[index],
                text_lang=args.text_lang,
                top_k=args.top_k,
                top_p=args.top_p,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                early_stop_num=args.early_stop_num,
                ready_step=0,
            )
        )
    return specs


def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    if args.request_manifest is not None:
        payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
        raw_requests = payload["requests"] if isinstance(payload, dict) else payload
        specs: List[SchedulerRequestSpec] = []
        for index, item in enumerate(raw_requests):
            text = item.get("text")
            text_file = item.get("text_file")
            if text is None and text_file is None:
                raise ValueError(f"request[{index}] must provide text or text_file")
            if text is None:
                text = Path(text_file).read_text(encoding="utf-8").strip()
            specs.append(
                SchedulerRequestSpec(
                    request_id=item.get("request_id", f"req_{index:03d}"),
                    ref_audio_path=Path(item["ref_audio_path"]),
                    prompt_text=item["prompt_text"],
                    prompt_lang=item.get("prompt_lang", "zh"),
                    text=text,
                    text_lang=item.get("text_lang", "zh"),
                    top_k=int(item.get("top_k", args.top_k)),
                    top_p=float(item.get("top_p", args.top_p)),
                    temperature=float(item.get("temperature", args.temperature)),
                    repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
                    early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
                    ready_step=int(item.get("ready_step", 0)),
                )
            )
        return specs
    if args.scenario == "single":
        return build_single_spec(args)
    return build_auto_specs(args)
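
# An illustrative --request-manifest payload for load_request_specs above.
# The keys mirror what the loop reads from each entry; paths are placeholders,
# and any omitted sampling field falls back to the corresponding CLI default:
#
#   {
#     "requests": [
#       {
#         "request_id": "req_000",
#         "ref_audio_path": "testwav/example.wav",
#         "prompt_text": "...",
#         "prompt_lang": "zh",
#         "text_file": "test_cn.txt",
#         "text_lang": "zh",
#         "top_k": 15,
#         "ready_step": 0
#       }
#     ]
#   }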


def load_pipeline(config_path: Path) -> TTS:
    tts_config = TTS_Config(str(config_path))
    print(tts_config)
    return TTS(tts_config)


def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
    if not (str(device).startswith("cuda") and torch.cuda.is_available()):
        return {
            "allocated_mb": 0.0,
            "reserved_mb": 0.0,
            "max_allocated_mb": 0.0,
            "max_reserved_mb": 0.0,
        }
    _sync_device(device)
    return {
        "allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
        "reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
        "max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
        "max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
    }
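
# Note: torch.cuda.memory_allocated counts bytes held by live tensors, while
# memory_reserved counts the larger pool the caching allocator keeps from the
# driver; the max_* variants are high-water marks since the last
# reset_peak_memory_stats call, which stage_run below uses to attribute peaks
# to individual stages.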


def stage_run(device: Any, fn: Callable[[], Any]) -> Tuple[Any, Dict[str, float]]:
    if str(device).startswith("cuda") and torch.cuda.is_available():
        gc.collect()
        _sync_device(device)
        torch.cuda.reset_peak_memory_stats(device)
    before = cuda_mem_snapshot(device)
    started = time.perf_counter()
    result = fn()
    _sync_device(device)
    elapsed_ms = (time.perf_counter() - started) * 1000.0
    after = cuda_mem_snapshot(device)
    after["elapsed_ms"] = float(elapsed_ms)
    after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
    after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
    after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
    return result, after


class GlobalPeakRecorder:
    def __init__(self, device: Any):
        self.device = device
        self.checkpoints: List[Dict[str, Any]] = []
        if str(device).startswith("cuda") and torch.cuda.is_available():
            gc.collect()
            _sync_device(device)
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)

    def record(self, label: str, **extra: Any) -> None:
        snapshot = cuda_mem_snapshot(self.device)
        snapshot["label"] = label
        snapshot.update(extra)
        self.checkpoints.append(snapshot)

    def summary(self) -> Dict[str, Any]:
        peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
        return {
            "peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
            "peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
            "peak_label": None if peak is None else peak["label"],
            "checkpoints": self.checkpoints,
        }
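
# Usage sketch for GlobalPeakRecorder (illustrative):
#
#   recorder = GlobalPeakRecorder(device)   # resets the CUDA peak counters
#   recorder.record("after_model_load")
#   ...                                     # run a stage of interest
#   recorder.record("after_prefill")
#   print(recorder.summary()["peak_label"])
#
# Because max_memory_allocated is monotone between resets, summary() blames
# the earliest checkpoint at which the final high-water mark had already been
# reached, so record() should be called immediately after each stage.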


def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
    per_request = []
    total = {
        "phones_bytes": 0,
        "prompt_phones_bytes": 0,
        "all_phones_bytes": 0,
        "all_bert_features_bytes": 0,
        "prompt_semantic_bytes": 0,
        "refer_spec_bytes": 0,
        "raw_audio_bytes": 0,
        "audio_16k_bytes": 0,
    }
    for state in states:
        spec_audio, audio_16k = state.refer_spec
        item = {
            "request_id": state.request_id,
            "prompt_semantic_len": int(state.prompt_semantic.shape[0]),
            "phones_len": int(state.phones.shape[0]),
            "all_phones_len": int(state.all_phones.shape[0]),
            "bert_frames": int(state.all_bert_features.shape[-1]),
            "phones_bytes": tensor_nbytes(state.phones),
            "prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
            "all_phones_bytes": tensor_nbytes(state.all_phones),
            "all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
            "prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
            "refer_spec_bytes": tensor_nbytes(spec_audio),
            "audio_16k_bytes": tensor_nbytes(audio_16k),
            "raw_audio_bytes": tensor_nbytes(state.raw_audio),
        }
        for key in total:
            total[key] += int(item[key])
        per_request.append(item)
    total["total_bytes"] = int(sum(total.values()))
    total["total_mb"] = bytes_to_mb(total["total_bytes"])
    return {"per_request": per_request, "total": total}


def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
    y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
    fields = {
        "x_bytes": tensor_nbytes(active_batch.x),
        "x_lens_bytes": tensor_nbytes(active_batch.x_lens),
        "prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
        "xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
        "key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
        "prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
        "y_sequence_bytes": y_sequence_bytes,
    }
    fields["total_bytes"] = int(sum(fields.values()))
    fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
    fields["batch_size"] = int(len(active_batch.states))
    fields["max_x_len"] = int(active_batch.x.shape[1])
    fields["src_len"] = int(active_batch.xy_pos.shape[1])
    fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
    return fields


def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
    per_request = []
    total_private_k_bytes = 0
    total_private_v_bytes = 0
    total_decode_mask_bytes = 0
    total_y_sequence_bytes = 0
    for item in running_requests:
        k_bytes = tensor_list_nbytes(item.k_cache)
        v_bytes = tensor_list_nbytes(item.v_cache)
        mask_bytes = tensor_nbytes(item.decode_attn_mask)
        y_bytes = tensor_nbytes(item.y_sequence)
        total_private_k_bytes += k_bytes
        total_private_v_bytes += v_bytes
        total_decode_mask_bytes += mask_bytes
        total_y_sequence_bytes += y_bytes
        per_request.append(
            {
                "request_id": item.state.request_id,
                "step_idx": int(item.step_idx),
                "prefix_len": int(item.prefix_len),
                "history_len": int(item.y_sequence.shape[0]),
                "kv_len": int(item.k_cache[0].shape[1]),
                "k_cache_bytes": k_bytes,
                "v_cache_bytes": v_bytes,
                "decode_mask_bytes": mask_bytes,
                "y_sequence_bytes": y_bytes,
            }
        )
    total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
    return {
        "per_request": per_request,
        "totals": {
            "private_k_cache_bytes": int(total_private_k_bytes),
            "private_v_cache_bytes": int(total_private_v_bytes),
            "private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
            "decode_mask_bytes": int(total_decode_mask_bytes),
            "y_sequence_bytes": int(total_y_sequence_bytes),
            "total_bytes": int(total_bytes),
            "total_mb": bytes_to_mb(total_bytes),
        },
    }


def summarise_decode_batch(
    xy_pos: torch.Tensor,
    batched_k_cache: Sequence[torch.Tensor],
    batched_v_cache: Sequence[torch.Tensor],
    batched_decode_attn_mask: Optional[torch.Tensor],
    running_requests: Sequence[T2SRunningRequest],
) -> Dict[str, Any]:
    private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
    private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
    batched_k_bytes = tensor_list_nbytes(batched_k_cache)
    batched_v_bytes = tensor_list_nbytes(batched_v_cache)
    batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
    xy_pos_bytes = tensor_nbytes(xy_pos)
    total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
    return {
        "batch_size": int(len(running_requests)),
        "xy_pos_bytes": int(xy_pos_bytes),
        "batched_k_cache_bytes": int(batched_k_bytes),
        "batched_v_cache_bytes": int(batched_v_bytes),
        "batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
        "batched_decode_mask_bytes": int(batched_mask_bytes),
        "private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
        "kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
        "total_bytes": int(total_bytes),
        "total_mb": bytes_to_mb(total_bytes),
        "xy_pos_shape": list(xy_pos.shape),
        "batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
        "layer_k_cache_shape": list(batched_k_cache[0].shape),
    }
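
# Interpreting kv_padding_overhead_bytes: if the batched decode step pads each
# request's KV cache to the longest sequence in the batch (as the field names
# above suggest), the per-layer overhead is roughly
#     sum_i (max_kv_len - kv_len_i) * bytes_per_token,
# e.g. four requests at kv lengths 100/200/300/400 padded to 400 waste 600
# token slots per layer compared with the private per-request caches.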


def summarise_decode_outputs(
    xy_dec: torch.Tensor,
    next_k_cache: Sequence[torch.Tensor],
    next_v_cache: Sequence[torch.Tensor],
) -> Dict[str, Any]:
    xy_dec_bytes = tensor_nbytes(xy_dec)
    next_k_bytes = tensor_list_nbytes(next_k_cache)
    next_v_bytes = tensor_list_nbytes(next_v_cache)
    total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
    return {
        "xy_dec_bytes": int(xy_dec_bytes),
        "next_k_cache_bytes": int(next_k_bytes),
        "next_v_cache_bytes": int(next_v_bytes),
        "next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
        "total_bytes": int(total_bytes),
        "total_mb": bytes_to_mb(total_bytes),
        "xy_dec_shape": list(xy_dec.shape),
        "layer_next_k_cache_shape": list(next_k_cache[0].shape),
    }


def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
    ranking = [
        ("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
        ("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
        ("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
        ("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
        ("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
        ("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
        ("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
    ]
    ranking.sort(key=lambda item: item[1], reverse=True)
    return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]


def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
    semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
    phones = state.phones.unsqueeze(0).to(tts.configs.device)
    audio_fragment = tts.synthesize_audio_request_local(
        semantic_tokens=semantic_tokens,
        phones=phones,
        prompt_semantic=state.prompt_semantic,
        prompt_phones=state.prompt_phones,
        refer_spec=state.refer_spec,
        raw_audio=state.raw_audio,
        raw_sr=state.raw_sr,
        speed=1.0,
        sample_steps=32,
    )
    output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
    return tts.audio_postprocess(
        audio=[[audio_fragment]],
        sr=int(output_sr),
        batch_index_list=None,
        speed_factor=1.0,
        split_bucket=False,
        fragment_interval=0.0,
        super_sampling=False,
    )


def simulate_worker_end_to_end(
    tts: TTS,
    specs: Sequence[SchedulerRequestSpec],
    max_steps: int,
    rounds: int,
    grad_mode: str = "default",
) -> Dict[str, Any]:
    device = tts.configs.device
    recorder = GlobalPeakRecorder(device)
    recorder.record("after_model_load")

    state_map: Dict[str, T2SRequestState] = {}
    per_round: List[Dict[str, Any]] = []

    for round_index in range(rounds):
        grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
        with grad_context():
            states = [prepare_request_state(tts, spec) for spec in specs]
            state_map = {state.request_id: state for state in states}
        recorder.record(
            "after_prepare_states",
            round_index=int(round_index),
            request_count=int(len(states)),
            grad_mode=grad_mode,
        )

        pending = list(states)
        running_requests: List[T2SRunningRequest] = []
        round_events: List[Dict[str, Any]] = []
        current_tick = 0

        while pending or running_requests:
            admitted = pending
            pending = []

            if admitted:
                recorder.record(
                    "before_prefill",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    admitted_count=int(len(admitted)),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )
                with grad_context():
                    admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
                recorder.record(
                    "after_prefill",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    admitted_running_count=int(len(admitted_running)),
                    admitted_finished_count=int(len(admitted_finished)),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )
                round_events.append(
                    {
                        "tick": int(current_tick),
                        "event": "prefill",
                        "admitted_count": int(len(admitted)),
                        "admitted_running_count": int(len(admitted_running)),
                        "admitted_finished_count": int(len(admitted_finished)),
                    }
                )
                for item in admitted_finished:
                    recorder.record(
                        "before_synth_prefill_finished",
                        round_index=int(round_index),
                        tick=int(current_tick),
                        running_count=int(len(running_requests)),
                        finished_request_id=item.request_id,
                        semantic_len=int(item.semantic_tokens.shape[0]),
                        grad_mode=grad_mode,
                    )
                    with grad_context():
                        sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
                    recorder.record(
                        "after_synth_prefill_finished",
                        round_index=int(round_index),
                        tick=int(current_tick),
                        running_count=int(len(running_requests)),
                        finished_request_id=item.request_id,
                        sample_rate=int(sample_rate),
                        audio_samples=int(audio_data.shape[0]),
                        grad_mode=grad_mode,
                    )
                running_requests.extend(admitted_running)
                recorder.record(
                    "after_extend_running",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )

            if running_requests:
                recorder.record(
                    "before_decode",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )
                with grad_context():
                    running_requests, step_finished = run_decode_step_for_running(
                        tts.t2s_model.model,
                        running_requests,
                        max_steps=max_steps,
                    )
                recorder.record(
                    "after_decode",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    running_count=int(len(running_requests)),
                    finished_count=int(len(step_finished)),
                    grad_mode=grad_mode,
                )
                round_events.append(
                    {
                        "tick": int(current_tick),
                        "event": "decode",
                        "running_count_after_decode": int(len(running_requests)),
                        "finished_count": int(len(step_finished)),
                    }
                )
                for item in step_finished:
                    recorder.record(
                        "before_synth_decode_finished",
                        round_index=int(round_index),
                        tick=int(current_tick),
                        running_count=int(len(running_requests)),
                        finished_request_id=item.request_id,
                        semantic_len=int(item.semantic_tokens.shape[0]),
                        grad_mode=grad_mode,
                    )
                    with grad_context():
                        sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
                    recorder.record(
                        "after_synth_decode_finished",
                        round_index=int(round_index),
                        tick=int(current_tick),
                        running_count=int(len(running_requests)),
                        finished_request_id=item.request_id,
                        sample_rate=int(sample_rate),
                        audio_samples=int(audio_data.shape[0]),
                        grad_mode=grad_mode,
                    )
            current_tick += 1

        recorder.record(
            "after_round_complete",
            round_index=int(round_index),
            running_count=0,
            grad_mode=grad_mode,
        )
        per_round.append(
            {
                "round_index": int(round_index),
                "events": round_events,
            }
        )

    return {
        "grad_mode": grad_mode,
        "rounds": per_round,
        "timeline": recorder.summary(),
    }


def main() -> None:
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tts = load_pipeline(args.config)
    model = tts.t2s_model.model
    device = tts.configs.device
    use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    set_seed(args.seed, use_cuda)

    specs = load_request_specs(args)
    if args.early_stop_num == -1:
        for spec in specs:
            spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)

    if args.warmup and specs:
        warmup_spec = specs[:1]
        _ = [prepare_request_state(tts, spec) for spec in warmup_spec]
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()
            _sync_device(device)

    states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
    request_state_summary = summarise_state_tensors(states)

    active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
    prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)

    prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
    running_requests, finished_items = prefill_result
    running_requests_summary = summarise_running_requests(running_requests)
    finished_after_prefill_summary = [
        {
            "request_id": item.request_id,
            "finish_idx": int(item.finish_idx),
            "finish_reason": item.finish_reason,
            "semantic_len": int(item.semantic_tokens.shape[0]),
        }
        for item in finished_items
    ]

    if not running_requests:
        raise RuntimeError(f"No running requests after prefill; every request finished on its first step: {[item.request_id for item in finished_items]}")

    decode_batch_result, decode_batch_mem = stage_run(
        device,
        lambda: _build_decode_batch_from_running(model, running_requests),
    )
    xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
    decode_batch_tensor_summary = summarise_decode_batch(
        xy_pos,
        batched_k_cache,
        batched_v_cache,
        batched_decode_attn_mask,
        running_requests,
    )

    decode_out_result, decode_step_mem = stage_run(
        device,
        lambda: model.t2s_transformer.decode_next_token(
            xy_pos,
            batched_k_cache,
            batched_v_cache,
            batched_decode_attn_mask,
        ),
    )
    xy_dec, next_k_cache, next_v_cache = decode_out_result
    decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
    del active_batch
    del running_requests
    del finished_items
    del xy_pos
    del batched_k_cache
    del batched_v_cache
    del batched_decode_attn_mask
    del xy_dec
    del next_k_cache
    del next_v_cache
    gc.collect()
    if use_cuda:
        _sync_device(device)
        torch.cuda.empty_cache()
    end_to_end_worker = simulate_worker_end_to_end(
        tts=tts,
        specs=specs,
        max_steps=args.max_steps,
        rounds=args.worker_rounds,
        grad_mode=args.worker_grad_mode,
    )
    live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
    worker_inference_mode = None
    if args.compare_worker_grad_modes:
        gc.collect()
        if use_cuda:
            _sync_device(device)
            torch.cuda.empty_cache()
        worker_inference_mode = simulate_worker_end_to_end(
            tts=tts,
            specs=specs,
            max_steps=args.max_steps,
            rounds=args.worker_rounds,
            grad_mode="inference_mode",
        )

    summary = {
        "meta": {
            "scenario": args.scenario if args.request_manifest is None else "manifest",
            "seed": int(args.seed),
            "device": str(device),
            "dtype": str(next(model.parameters()).dtype),
            "request_count": int(len(specs)),
            "num_layers": int(model.num_layers),
            "num_heads": int(model.num_head),
            "model_dim": int(model.model_dim),
            "model_weights_mb": bytes_to_mb(model_nbytes(model)),
        },
        "loaded_module_weights": build_module_weight_summary(tts),
        "requests": [
            {
                "request_id": spec.request_id,
                "ref_audio_path": str(spec.ref_audio_path),
                "prompt_text": spec.prompt_text,
                "text": spec.text,
            }
            for spec in specs
        ],
        "prepare_stage": {
            "memory": prepare_mem,
            "request_state": request_state_summary,
        },
        "prefill_batch": {
            "memory": prefill_batch_mem,
            "tensor_bytes": prefill_batch_tensor_summary,
        },
        "prefill_step": {
            "memory": prefill_step_mem,
            "running_requests": running_requests_summary,
            "finished_after_prefill": finished_after_prefill_summary,
        },
        "decode_batch": {
            "memory": decode_batch_mem,
            "tensor_bytes": decode_batch_tensor_summary,
        },
        "decode_outputs": {
            "memory": decode_step_mem,
            "tensor_bytes": decode_output_tensor_summary,
        },
        "end_to_end_worker": end_to_end_worker,
        "live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
        "end_to_end_worker_inference_mode": worker_inference_mode,
    }
    summary["top_rankings"] = top_rankings(summary)

    summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
    summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")

    print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
    print("[top_rankings]")
    for item in summary["top_rankings"]:
        print(f"- {item['name']}: {item['mb']:.3f} MB")
    print("[worker_peak]")
    print(
        json.dumps(
            {
                "peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
                "peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
                "peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    if worker_inference_mode is not None:
        print("[worker_peak_inference_mode]")
        print(
            json.dumps(
                {
                    "peak_label": worker_inference_mode["timeline"]["peak_label"],
                    "peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
                    "peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
                },
                ensure_ascii=False,
                indent=2,
            )
        )
    print(f"[summary] {summary_path}")


if __name__ == "__main__":
    main()