mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-16 15:28:14 +08:00
Refactor T2S scheduler and inference handling to improve attention mask management and memory tracking. Update T2SRunningRequest and T2SActiveBatch classes to include optional key padding masks. Introduce new benchmarking tools for API performance and memory usage analysis, enhancing overall system efficiency.
This commit is contained in:
parent
dc37b0b9ef
commit
d245eb169c
@ -1781,6 +1781,7 @@ class TTS:
|
||||
|
||||
return audio
|
||||
|
||||
@torch.inference_mode()
|
||||
def synthesize_audio_request_local(
|
||||
self,
|
||||
semantic_tokens: torch.Tensor,
|
||||
|
||||
@ -70,7 +70,7 @@ class T2SRunningRequest:
|
||||
state: T2SRequestState
|
||||
y_sequence: torch.LongTensor
|
||||
prefix_len: int
|
||||
decode_attn_mask: torch.Tensor
|
||||
decode_attn_mask: Optional[torch.Tensor]
|
||||
k_cache: List[torch.Tensor]
|
||||
v_cache: List[torch.Tensor]
|
||||
step_idx: int
|
||||
@ -93,6 +93,7 @@ class T2SActiveBatch:
|
||||
y_sequences: List[torch.LongTensor]
|
||||
prefix_lens: torch.LongTensor
|
||||
xy_pos: torch.Tensor
|
||||
key_padding_mask: torch.Tensor
|
||||
prefill_attn_mask: torch.Tensor
|
||||
decode_attn_mask: Optional[torch.Tensor]
|
||||
k_cache: Optional[List[torch.Tensor]]
|
||||
@ -110,6 +111,7 @@ def normalize_sentence(text: str, language: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_request_state(
|
||||
tts: Any,
|
||||
spec: SchedulerRequestSpec,
|
||||
@ -212,6 +214,7 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device:
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
|
||||
x_items: List[torch.Tensor] = []
|
||||
y_pos_items: List[torch.Tensor] = []
|
||||
@ -240,22 +243,19 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct
|
||||
device = x_batch.device
|
||||
x_lens_tensor = torch.LongTensor(x_lens).to(device)
|
||||
prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device)
|
||||
batch_size = len(states)
|
||||
src_len = max_x_len + max_prefix_len
|
||||
|
||||
x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len)
|
||||
y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len)
|
||||
padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1)
|
||||
key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool()
|
||||
x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True)
|
||||
y_mask = F.pad(
|
||||
torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1),
|
||||
(max_x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1)
|
||||
padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1)
|
||||
attn_mask = causal_mask.logical_or(padding_mask)
|
||||
attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool()
|
||||
causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0)
|
||||
attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1)
|
||||
|
||||
return T2SActiveBatch(
|
||||
request_ids=[state.request_id for state in states],
|
||||
@ -265,6 +265,7 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct
|
||||
y_sequences=y_sequences,
|
||||
prefix_lens=prefix_lens_tensor,
|
||||
xy_pos=xy_pos,
|
||||
key_padding_mask=key_padding_mask,
|
||||
prefill_attn_mask=attn_mask,
|
||||
decode_attn_mask=None,
|
||||
k_cache=None,
|
||||
@ -322,10 +323,11 @@ def _sample_per_request(
|
||||
finish_reason = "eos_argmax"
|
||||
|
||||
if finish_reason is not None:
|
||||
prefix_len = int(active_batch.prefix_lens[batch_index].item())
|
||||
finished_items.append(
|
||||
T2SFinishedItem(
|
||||
request_id=state.request_id,
|
||||
semantic_tokens=new_history[:-1].clone(),
|
||||
semantic_tokens=new_history[prefix_len:-1].clone(),
|
||||
finish_idx=step_idx,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@ -346,11 +348,7 @@ def decode_one_step(
|
||||
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt(
|
||||
active_batch.xy_pos, active_batch.prefill_attn_mask, None
|
||||
)
|
||||
active_batch.decode_attn_mask = F.pad(
|
||||
active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2),
|
||||
(0, 1),
|
||||
value=False,
|
||||
)
|
||||
active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
|
||||
active_batch.prefill_done = True
|
||||
else:
|
||||
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token(
|
||||
@ -411,6 +409,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
|
||||
return F.pad(mask, (pad_len, 0), value=True)
|
||||
|
||||
|
||||
def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor:
|
||||
if running_request.decode_attn_mask is not None:
|
||||
return running_request.decode_attn_mask
|
||||
current_mask_len = running_request.k_cache[0].shape[1] + 1
|
||||
return torch.zeros(
|
||||
(1, 1, 1, current_mask_len),
|
||||
dtype=torch.bool,
|
||||
device=running_request.k_cache[0].device,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_step(
|
||||
model: Any,
|
||||
states: Sequence[T2SRequestState],
|
||||
@ -421,11 +431,9 @@ def run_prefill_step(
|
||||
|
||||
active_batch = build_prefill_batch(model, states)
|
||||
xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None)
|
||||
decode_attn_mask = F.pad(
|
||||
active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2),
|
||||
(0, 1),
|
||||
value=False,
|
||||
)
|
||||
decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
|
||||
if len(states) == 1 and not decode_attn_mask.any().item():
|
||||
decode_attn_mask = None
|
||||
logits = model.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
running_requests: List[T2SRunningRequest] = []
|
||||
@ -463,7 +471,7 @@ def run_prefill_step(
|
||||
finished_items.append(
|
||||
T2SFinishedItem(
|
||||
request_id=state.request_id,
|
||||
semantic_tokens=new_history[:-1].clone(),
|
||||
semantic_tokens=new_history[prefix_len:-1].clone(),
|
||||
finish_idx=0,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@ -479,7 +487,11 @@ def run_prefill_step(
|
||||
state=state,
|
||||
y_sequence=new_history,
|
||||
prefix_len=prefix_len,
|
||||
decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(),
|
||||
decode_attn_mask=(
|
||||
None
|
||||
if decode_attn_mask is None
|
||||
else decode_attn_mask[batch_index : batch_index + 1].clone()
|
||||
),
|
||||
k_cache=request_k_cache,
|
||||
v_cache=request_v_cache,
|
||||
step_idx=1,
|
||||
@ -492,10 +504,9 @@ def run_prefill_step(
|
||||
def _build_decode_batch_from_running(
|
||||
model: Any,
|
||||
running_requests: Sequence[T2SRunningRequest],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]:
|
||||
xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests])
|
||||
max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests)
|
||||
max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests)
|
||||
num_layers = len(running_requests[0].k_cache)
|
||||
|
||||
batched_k_cache: List[torch.Tensor] = []
|
||||
@ -508,13 +519,19 @@ def _build_decode_batch_from_running(
|
||||
torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0)
|
||||
)
|
||||
|
||||
batched_decode_attn_mask = torch.cat(
|
||||
[_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests],
|
||||
dim=0,
|
||||
)
|
||||
if all(item.decode_attn_mask is None for item in running_requests):
|
||||
batched_decode_attn_mask = None
|
||||
else:
|
||||
materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests]
|
||||
max_mask_len = max(mask.shape[-1] for mask in materialized_masks)
|
||||
batched_decode_attn_mask = torch.cat(
|
||||
[_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks],
|
||||
dim=0,
|
||||
)
|
||||
return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_decode_step_for_running(
|
||||
model: Any,
|
||||
running_requests: Sequence[T2SRunningRequest],
|
||||
@ -568,7 +585,7 @@ def run_decode_step_for_running(
|
||||
finished_items.append(
|
||||
T2SFinishedItem(
|
||||
request_id=running_request.state.request_id,
|
||||
semantic_tokens=new_history[:-1].clone(),
|
||||
semantic_tokens=new_history[running_request.prefix_len:-1].clone(),
|
||||
finish_idx=current_idx,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@ -578,12 +595,20 @@ def run_decode_step_for_running(
|
||||
real_next_kv_len = running_request.k_cache[0].shape[1] + 1
|
||||
request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache]
|
||||
request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache]
|
||||
if batched_decode_attn_mask is None:
|
||||
next_decode_attn_mask = None
|
||||
else:
|
||||
current_decode_mask_len = running_request.k_cache[0].shape[1] + 1
|
||||
current_decode_attn_mask = batched_decode_attn_mask[
|
||||
batch_index : batch_index + 1, :, :, -current_decode_mask_len:
|
||||
]
|
||||
next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False)
|
||||
next_running.append(
|
||||
T2SRunningRequest(
|
||||
state=running_request.state,
|
||||
y_sequence=new_history,
|
||||
prefix_len=running_request.prefix_len,
|
||||
decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False),
|
||||
decode_attn_mask=next_decode_attn_mask,
|
||||
k_cache=request_k_cache,
|
||||
v_cache=request_v_cache,
|
||||
step_idx=current_idx + 1,
|
||||
@ -593,6 +618,7 @@ def run_decode_step_for_running(
|
||||
return next_running, finished_items
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_scheduler_continuous(
|
||||
model: Any,
|
||||
states: Sequence[T2SRequestState],
|
||||
|
||||
250
tools/bench_api_v3_scheduler_submit.py
Normal file
250
tools/bench_api_v3_scheduler_submit.py
Normal file
@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.")
|
||||
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880")
|
||||
parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit")
|
||||
parser.add_argument("--concurrency", type=int, required=True)
|
||||
parser.add_argument("--timeout-sec", type=float, default=120.0)
|
||||
parser.add_argument("--server-pid", type=int, default=None)
|
||||
parser.add_argument("--poll-interval-sec", type=float, default=0.1)
|
||||
parser.add_argument("--text-lang", type=str, default="zh")
|
||||
parser.add_argument("--prompt-lang", type=str, default="zh")
|
||||
parser.add_argument("--media-type", type=str, default="wav")
|
||||
parser.add_argument("--top-k", type=int, default=15)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.35)
|
||||
parser.add_argument("--sample-steps", type=int, default=32)
|
||||
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
|
||||
parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav")
|
||||
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
|
||||
wav_paths_all = sorted(args.wav_dir.glob("*.wav"))
|
||||
wav_paths: List[Path] = []
|
||||
for wav_path in wav_paths_all:
|
||||
with wave.open(str(wav_path), "rb") as handle:
|
||||
duration = handle.getnframes() / float(handle.getframerate())
|
||||
if 3.0 <= duration <= 10.0:
|
||||
wav_paths.append(wav_path)
|
||||
if not wav_paths:
|
||||
raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")
|
||||
text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
|
||||
if not text_lines:
|
||||
raise ValueError(f"没有找到有效文本行: {args.text_file}")
|
||||
|
||||
requests: List[Dict[str, Any]] = []
|
||||
for index in range(args.concurrency):
|
||||
wav_path = wav_paths[index % len(wav_paths)]
|
||||
lab_path = wav_path.with_suffix(".lab")
|
||||
if not lab_path.exists():
|
||||
raise FileNotFoundError(f"缺少参考文本: {lab_path}")
|
||||
requests.append(
|
||||
{
|
||||
"request_id": f"bench_{args.concurrency:03d}_{index:03d}",
|
||||
"text": text_lines[index % len(text_lines)],
|
||||
"text_lang": args.text_lang,
|
||||
"ref_audio_path": str(wav_path),
|
||||
"prompt_lang": args.prompt_lang,
|
||||
"prompt_text": lab_path.read_text(encoding="utf-8").strip(),
|
||||
"top_k": int(args.top_k),
|
||||
"top_p": float(args.top_p),
|
||||
"temperature": float(args.temperature),
|
||||
"repetition_penalty": float(args.repetition_penalty),
|
||||
"sample_steps": int(args.sample_steps),
|
||||
"media_type": args.media_type,
|
||||
"timeout_sec": float(args.timeout_sec),
|
||||
}
|
||||
)
|
||||
return requests
|
||||
|
||||
|
||||
class GpuMemoryPoller:
|
||||
def __init__(self, server_pid: Optional[int], interval_sec: float):
|
||||
self.server_pid = server_pid
|
||||
self.interval_sec = interval_sec
|
||||
self._stop = threading.Event()
|
||||
self.samples: List[Dict[str, Any]] = []
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
|
||||
def _query_memory_mb(self) -> Optional[int]:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-compute-apps=pid,used_gpu_memory",
|
||||
"--format=csv,noheader,nounits",
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
total = 0
|
||||
found = False
|
||||
for line in result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = [item.strip() for item in line.split(",")]
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
used_mb = int(parts[1])
|
||||
except ValueError:
|
||||
continue
|
||||
if self.server_pid is None or pid == self.server_pid:
|
||||
total += used_mb
|
||||
found = True
|
||||
if self.server_pid is None:
|
||||
return total
|
||||
return total if found else 0
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
used_mb = self._query_memory_mb()
|
||||
self.samples.append({"ts": time.time(), "used_mb": used_mb})
|
||||
self._stop.wait(self.interval_sec)
|
||||
|
||||
def start(self) -> None:
|
||||
self.thread = threading.Thread(target=self._run, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
if self.thread is not None:
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
valid = [item for item in self.samples if item["used_mb"] is not None]
|
||||
peak = max(valid, key=lambda item: item["used_mb"]) if valid else None
|
||||
first = valid[0] if valid else None
|
||||
last = valid[-1] if valid else None
|
||||
return {
|
||||
"server_pid": self.server_pid,
|
||||
"sample_count": int(len(self.samples)),
|
||||
"start_used_mb": None if first is None else int(first["used_mb"]),
|
||||
"peak_used_mb": None if peak is None else int(peak["used_mb"]),
|
||||
"peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]),
|
||||
"end_used_mb": None if last is None else int(last["used_mb"]),
|
||||
"peak_ts": None if peak is None else float(peak["ts"]),
|
||||
"samples": self.samples,
|
||||
}
|
||||
|
||||
|
||||
async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
started = time.perf_counter()
|
||||
try:
|
||||
response = await client.post(url, json=payload)
|
||||
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
||||
item = {
|
||||
"request_id": payload["request_id"],
|
||||
"status_code": int(response.status_code),
|
||||
"elapsed_ms": float(elapsed_ms),
|
||||
"content_type": response.headers.get("content-type"),
|
||||
"audio_bytes": int(len(response.content)),
|
||||
"headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")},
|
||||
}
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
item["error_body"] = response.json()
|
||||
except Exception:
|
||||
item["error_body"] = response.text
|
||||
return item
|
||||
except Exception as exc:
|
||||
return {
|
||||
"request_id": payload["request_id"],
|
||||
"status_code": -1,
|
||||
"elapsed_ms": float((time.perf_counter() - started) * 1000.0),
|
||||
"exception": repr(exc),
|
||||
}
|
||||
|
||||
|
||||
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
|
||||
payloads = load_requests(args)
|
||||
url = args.base_url.rstrip("/") + args.endpoint
|
||||
poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)
|
||||
|
||||
limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
|
||||
timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0)
|
||||
|
||||
started = time.perf_counter()
|
||||
poller.start()
|
||||
try:
|
||||
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
|
||||
results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads])
|
||||
finally:
|
||||
poller.stop()
|
||||
wall_ms = (time.perf_counter() - started) * 1000.0
|
||||
|
||||
ok_results = [item for item in results if item["status_code"] == 200]
|
||||
failed_results = [item for item in results if item["status_code"] != 200]
|
||||
request_total_ms = []
|
||||
worker_total_ms = []
|
||||
for item in ok_results:
|
||||
headers = item.get("headers", {})
|
||||
if "x-request-total-ms" in headers:
|
||||
request_total_ms.append(float(headers["x-request-total-ms"]))
|
||||
if "x-worker-total-ms" in headers:
|
||||
worker_total_ms.append(float(headers["x-worker-total-ms"]))
|
||||
|
||||
return {
|
||||
"concurrency": int(args.concurrency),
|
||||
"server_pid": args.server_pid,
|
||||
"request_count": int(len(payloads)),
|
||||
"wall_ms": float(wall_ms),
|
||||
"success_count": int(len(ok_results)),
|
||||
"failure_count": int(len(failed_results)),
|
||||
"request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None,
|
||||
"request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None,
|
||||
"worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None,
|
||||
"worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None,
|
||||
"gpu_memory": poller.summary(),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
summary = asyncio.run(run_benchmark(args))
|
||||
summary_path = output_dir / "summary.json"
|
||||
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(json.dumps({
|
||||
"concurrency": summary["concurrency"],
|
||||
"success_count": summary["success_count"],
|
||||
"failure_count": summary["failure_count"],
|
||||
"wall_ms": summary["wall_ms"],
|
||||
"gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"],
|
||||
"request_total_ms_avg": summary["request_total_ms_avg"],
|
||||
"request_total_ms_max": summary["request_total_ms_max"],
|
||||
"summary_path": str(summary_path),
|
||||
}, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
887
tools/t2s_memory_breakdown.py
Normal file
887
tools/t2s_memory_breakdown.py
Normal file
@ -0,0 +1,887 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
|
||||
if str(gpt_sovits_dir) not in sys.path:
|
||||
sys.path.append(str(gpt_sovits_dir))
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
|
||||
SchedulerRequestSpec,
|
||||
T2SRequestState,
|
||||
T2SRunningRequest,
|
||||
_build_decode_batch_from_running,
|
||||
build_prefill_batch,
|
||||
prepare_request_state,
|
||||
run_decode_step_for_running,
|
||||
run_prefill_step,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
|
||||
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
|
||||
parser.add_argument("--request-manifest", type=Path, default=None)
|
||||
parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
|
||||
parser.add_argument("--auto-count", type=int, default=4)
|
||||
parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
|
||||
parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
|
||||
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
|
||||
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
|
||||
parser.add_argument("--prompt-lang", type=str, default="zh")
|
||||
parser.add_argument("--text", type=str, default=None)
|
||||
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
|
||||
parser.add_argument("--text-lang", type=str, default="zh")
|
||||
parser.add_argument("--top-k", type=int, default=15)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.35)
|
||||
parser.add_argument("--early-stop-num", type=int, default=-1)
|
||||
parser.add_argument("--max-steps", type=int, default=1500)
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
parser.add_argument("--warmup", action="store_true", default=False)
|
||||
parser.add_argument("--worker-rounds", type=int, default=1)
|
||||
parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
|
||||
parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def set_seed(seed: int, use_cuda: bool) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if use_cuda and torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _sync_device(device: Any) -> None:
|
||||
try:
|
||||
device_str = str(device)
|
||||
if device_str.startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize(device)
|
||||
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def bytes_to_mb(num_bytes: int) -> float:
|
||||
return float(num_bytes) / (1024.0 * 1024.0)
|
||||
|
||||
|
||||
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
|
||||
if tensor is None:
|
||||
return 0
|
||||
return int(tensor.numel() * tensor.element_size())
|
||||
|
||||
|
||||
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
|
||||
return int(sum(tensor_nbytes(item) for item in items))
|
||||
|
||||
|
||||
def model_nbytes(module: torch.nn.Module) -> int:
|
||||
total = 0
|
||||
for parameter in module.parameters():
|
||||
total += tensor_nbytes(parameter)
|
||||
for buffer in module.buffers():
|
||||
total += tensor_nbytes(buffer)
|
||||
return int(total)
|
||||
|
||||
|
||||
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
|
||||
modules = {
|
||||
"t2s_model": tts.t2s_model,
|
||||
"t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
|
||||
"vits_model": tts.vits_model,
|
||||
"bert_model": tts.bert_model,
|
||||
"cnhuhbert_model": tts.cnhuhbert_model,
|
||||
"vocoder": tts.vocoder,
|
||||
"sv_model": tts.sv_model,
|
||||
}
|
||||
by_module = {}
|
||||
total_bytes = 0
|
||||
for name, module in modules.items():
|
||||
module_bytes = model_nbytes(module) if module is not None else 0
|
||||
by_module[name] = {
|
||||
"bytes": int(module_bytes),
|
||||
"mb": bytes_to_mb(module_bytes),
|
||||
}
|
||||
total_bytes += module_bytes
|
||||
return {
|
||||
"by_module": by_module,
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
}
|
||||
|
||||
|
||||
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
|
||||
storages: Dict[int, Dict[str, Any]] = {}
|
||||
tensor_views: List[Dict[str, Any]] = []
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
tensor = None
|
||||
if torch.is_tensor(obj):
|
||||
tensor = obj
|
||||
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
|
||||
tensor = obj.data
|
||||
if tensor is None or not tensor.is_cuda:
|
||||
continue
|
||||
storage = tensor.untyped_storage()
|
||||
storage_ptr = int(storage.data_ptr())
|
||||
if storage_ptr not in storages:
|
||||
storages[storage_ptr] = {
|
||||
"storage_ptr": storage_ptr,
|
||||
"storage_bytes": int(storage.nbytes()),
|
||||
"dtype": str(tensor.dtype),
|
||||
"shape": list(tensor.shape),
|
||||
"device": str(tensor.device),
|
||||
}
|
||||
tensor_views.append(
|
||||
{
|
||||
"shape": list(tensor.shape),
|
||||
"dtype": str(tensor.dtype),
|
||||
"bytes": tensor_nbytes(tensor),
|
||||
"device": str(tensor.device),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
|
||||
tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
|
||||
return {
|
||||
"unique_storage_count": int(len(storage_list)),
|
||||
"unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
|
||||
"unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
|
||||
"top_storages": storage_list[:top_k],
|
||||
"top_tensor_views": tensor_views[:top_k],
|
||||
}
|
||||
|
||||
|
||||
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
|
||||
return [
|
||||
SchedulerRequestSpec(
|
||||
request_id="req_000",
|
||||
ref_audio_path=args.ref_audio,
|
||||
prompt_text=args.prompt_text,
|
||||
prompt_lang=args.prompt_lang,
|
||||
text=text,
|
||||
text_lang=args.text_lang,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
early_stop_num=args.early_stop_num,
|
||||
ready_step=0,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
|
||||
if len(wav_paths) < args.auto_count:
|
||||
raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
|
||||
text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
|
||||
if len(text_lines) < args.auto_count:
|
||||
raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, wav_path in enumerate(wav_paths):
|
||||
lab_path = wav_path.with_suffix(".lab")
|
||||
if not lab_path.exists():
|
||||
raise FileNotFoundError(f"找不到参考文本 {lab_path}")
|
||||
specs.append(
|
||||
SchedulerRequestSpec(
|
||||
request_id=f"req_{index:03d}",
|
||||
ref_audio_path=wav_path,
|
||||
prompt_text=lab_path.read_text(encoding="utf-8").strip(),
|
||||
prompt_lang="zh",
|
||||
text=text_lines[index],
|
||||
text_lang=args.text_lang,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
early_stop_num=args.early_stop_num,
|
||||
ready_step=0,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
if args.request_manifest is not None:
|
||||
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
|
||||
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, item in enumerate(raw_requests):
|
||||
text = item.get("text")
|
||||
text_file = item.get("text_file")
|
||||
if text is None and text_file is None:
|
||||
raise ValueError(f"request[{index}] must provide text or text_file")
|
||||
if text is None:
|
||||
text = Path(text_file).read_text(encoding="utf-8").strip()
|
||||
specs.append(
|
||||
SchedulerRequestSpec(
|
||||
request_id=item.get("request_id", f"req_{index:03d}"),
|
||||
ref_audio_path=Path(item["ref_audio_path"]),
|
||||
prompt_text=item["prompt_text"],
|
||||
prompt_lang=item.get("prompt_lang", "zh"),
|
||||
text=text,
|
||||
text_lang=item.get("text_lang", "zh"),
|
||||
top_k=int(item.get("top_k", args.top_k)),
|
||||
top_p=float(item.get("top_p", args.top_p)),
|
||||
temperature=float(item.get("temperature", args.temperature)),
|
||||
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
|
||||
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
|
||||
ready_step=int(item.get("ready_step", 0)),
|
||||
)
|
||||
)
|
||||
return specs
|
||||
if args.scenario == "single":
|
||||
return build_single_spec(args)
|
||||
return build_auto_specs(args)
|
||||
|
||||
|
||||
def load_pipeline(config_path: Path) -> TTS:
|
||||
tts_config = TTS_Config(str(config_path))
|
||||
print(tts_config)
|
||||
return TTS(tts_config)
|
||||
|
||||
|
||||
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
|
||||
if not (str(device).startswith("cuda") and torch.cuda.is_available()):
|
||||
return {
|
||||
"allocated_mb": 0.0,
|
||||
"reserved_mb": 0.0,
|
||||
"max_allocated_mb": 0.0,
|
||||
"max_reserved_mb": 0.0,
|
||||
}
|
||||
_sync_device(device)
|
||||
return {
|
||||
"allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
|
||||
"reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
|
||||
"max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
|
||||
"max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
|
||||
}
|
||||
|
||||
|
||||
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
gc.collect()
|
||||
_sync_device(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
before = cuda_mem_snapshot(device)
|
||||
started = time.perf_counter()
|
||||
result = fn()
|
||||
_sync_device(device)
|
||||
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
||||
after = cuda_mem_snapshot(device)
|
||||
after["elapsed_ms"] = float(elapsed_ms)
|
||||
after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
|
||||
after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
|
||||
after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
|
||||
return result, after
|
||||
|
||||
|
||||
class GlobalPeakRecorder:
|
||||
def __init__(self, device: Any):
|
||||
self.device = device
|
||||
self.checkpoints: List[Dict[str, Any]] = []
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
gc.collect()
|
||||
_sync_device(device)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
|
||||
def record(self, label: str, **extra: Any) -> None:
|
||||
snapshot = cuda_mem_snapshot(self.device)
|
||||
snapshot["label"] = label
|
||||
snapshot.update(extra)
|
||||
self.checkpoints.append(snapshot)
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
|
||||
return {
|
||||
"peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
|
||||
"peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
|
||||
"peak_label": None if peak is None else peak["label"],
|
||||
"checkpoints": self.checkpoints,
|
||||
}
|
||||
|
||||
|
||||
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
|
||||
per_request = []
|
||||
total = {
|
||||
"phones_bytes": 0,
|
||||
"prompt_phones_bytes": 0,
|
||||
"all_phones_bytes": 0,
|
||||
"all_bert_features_bytes": 0,
|
||||
"prompt_semantic_bytes": 0,
|
||||
"refer_spec_bytes": 0,
|
||||
"raw_audio_bytes": 0,
|
||||
"audio_16k_bytes": 0,
|
||||
}
|
||||
for state in states:
|
||||
spec_audio, audio_16k = state.refer_spec
|
||||
item = {
|
||||
"request_id": state.request_id,
|
||||
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
|
||||
"phones_len": int(state.phones.shape[0]),
|
||||
"all_phones_len": int(state.all_phones.shape[0]),
|
||||
"bert_frames": int(state.all_bert_features.shape[-1]),
|
||||
"phones_bytes": tensor_nbytes(state.phones),
|
||||
"prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
|
||||
"all_phones_bytes": tensor_nbytes(state.all_phones),
|
||||
"all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
|
||||
"prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
|
||||
"refer_spec_bytes": tensor_nbytes(spec_audio),
|
||||
"audio_16k_bytes": tensor_nbytes(audio_16k),
|
||||
"raw_audio_bytes": tensor_nbytes(state.raw_audio),
|
||||
}
|
||||
for key in total:
|
||||
total[key] += int(item[key])
|
||||
per_request.append(item)
|
||||
total["total_bytes"] = int(sum(total.values()))
|
||||
total["total_mb"] = bytes_to_mb(total["total_bytes"])
|
||||
return {"per_request": per_request, "total": total}
|
||||
|
||||
|
||||
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
|
||||
y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
|
||||
fields = {
|
||||
"x_bytes": tensor_nbytes(active_batch.x),
|
||||
"x_lens_bytes": tensor_nbytes(active_batch.x_lens),
|
||||
"prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
|
||||
"xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
|
||||
"key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
|
||||
"prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
|
||||
"y_sequence_bytes": y_sequence_bytes,
|
||||
}
|
||||
fields["total_bytes"] = int(sum(fields.values()))
|
||||
fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
|
||||
fields["batch_size"] = int(len(active_batch.states))
|
||||
fields["max_x_len"] = int(active_batch.x.shape[1])
|
||||
fields["src_len"] = int(active_batch.xy_pos.shape[1])
|
||||
fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
|
||||
return fields
|
||||
|
||||
|
||||
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
|
||||
per_request = []
|
||||
total_private_k_bytes = 0
|
||||
total_private_v_bytes = 0
|
||||
total_decode_mask_bytes = 0
|
||||
total_y_sequence_bytes = 0
|
||||
for item in running_requests:
|
||||
k_bytes = tensor_list_nbytes(item.k_cache)
|
||||
v_bytes = tensor_list_nbytes(item.v_cache)
|
||||
mask_bytes = tensor_nbytes(item.decode_attn_mask)
|
||||
y_bytes = tensor_nbytes(item.y_sequence)
|
||||
total_private_k_bytes += k_bytes
|
||||
total_private_v_bytes += v_bytes
|
||||
total_decode_mask_bytes += mask_bytes
|
||||
total_y_sequence_bytes += y_bytes
|
||||
per_request.append(
|
||||
{
|
||||
"request_id": item.state.request_id,
|
||||
"step_idx": int(item.step_idx),
|
||||
"prefix_len": int(item.prefix_len),
|
||||
"history_len": int(item.y_sequence.shape[0]),
|
||||
"kv_len": int(item.k_cache[0].shape[1]),
|
||||
"k_cache_bytes": k_bytes,
|
||||
"v_cache_bytes": v_bytes,
|
||||
"decode_mask_bytes": mask_bytes,
|
||||
"y_sequence_bytes": y_bytes,
|
||||
}
|
||||
)
|
||||
total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
|
||||
return {
|
||||
"per_request": per_request,
|
||||
"totals": {
|
||||
"private_k_cache_bytes": int(total_private_k_bytes),
|
||||
"private_v_cache_bytes": int(total_private_v_bytes),
|
||||
"private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
|
||||
"decode_mask_bytes": int(total_decode_mask_bytes),
|
||||
"y_sequence_bytes": int(total_y_sequence_bytes),
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def summarise_decode_batch(
|
||||
xy_pos: torch.Tensor,
|
||||
batched_k_cache: Sequence[torch.Tensor],
|
||||
batched_v_cache: Sequence[torch.Tensor],
|
||||
batched_decode_attn_mask: Optional[torch.Tensor],
|
||||
running_requests: Sequence[T2SRunningRequest],
|
||||
) -> Dict[str, Any]:
|
||||
private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
|
||||
private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
|
||||
batched_k_bytes = tensor_list_nbytes(batched_k_cache)
|
||||
batched_v_bytes = tensor_list_nbytes(batched_v_cache)
|
||||
batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
|
||||
xy_pos_bytes = tensor_nbytes(xy_pos)
|
||||
total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
|
||||
return {
|
||||
"batch_size": int(len(running_requests)),
|
||||
"xy_pos_bytes": int(xy_pos_bytes),
|
||||
"batched_k_cache_bytes": int(batched_k_bytes),
|
||||
"batched_v_cache_bytes": int(batched_v_bytes),
|
||||
"batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
|
||||
"batched_decode_mask_bytes": int(batched_mask_bytes),
|
||||
"private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
|
||||
"kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
"xy_pos_shape": list(xy_pos.shape),
|
||||
"batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
|
||||
"layer_k_cache_shape": list(batched_k_cache[0].shape),
|
||||
}
|
||||
|
||||
|
||||
def summarise_decode_outputs(
|
||||
xy_dec: torch.Tensor,
|
||||
next_k_cache: Sequence[torch.Tensor],
|
||||
next_v_cache: Sequence[torch.Tensor],
|
||||
) -> Dict[str, Any]:
|
||||
xy_dec_bytes = tensor_nbytes(xy_dec)
|
||||
next_k_bytes = tensor_list_nbytes(next_k_cache)
|
||||
next_v_bytes = tensor_list_nbytes(next_v_cache)
|
||||
total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
|
||||
return {
|
||||
"xy_dec_bytes": int(xy_dec_bytes),
|
||||
"next_k_cache_bytes": int(next_k_bytes),
|
||||
"next_v_cache_bytes": int(next_v_bytes),
|
||||
"next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
"xy_dec_shape": list(xy_dec.shape),
|
||||
"layer_next_k_cache_shape": list(next_k_cache[0].shape),
|
||||
}
|
||||
|
||||
|
||||
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
ranking = [
|
||||
("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
|
||||
("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
|
||||
("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
|
||||
("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
|
||||
("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
|
||||
("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
|
||||
("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
|
||||
]
|
||||
ranking.sort(key=lambda item: item[1], reverse=True)
|
||||
return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]
|
||||
|
||||
|
||||
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
|
||||
semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
|
||||
phones = state.phones.unsqueeze(0).to(tts.configs.device)
|
||||
audio_fragment = tts.synthesize_audio_request_local(
|
||||
semantic_tokens=semantic_tokens,
|
||||
phones=phones,
|
||||
prompt_semantic=state.prompt_semantic,
|
||||
prompt_phones=state.prompt_phones,
|
||||
refer_spec=state.refer_spec,
|
||||
raw_audio=state.raw_audio,
|
||||
raw_sr=state.raw_sr,
|
||||
speed=1.0,
|
||||
sample_steps=32,
|
||||
)
|
||||
output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
|
||||
return tts.audio_postprocess(
|
||||
audio=[[audio_fragment]],
|
||||
sr=int(output_sr),
|
||||
batch_index_list=None,
|
||||
speed_factor=1.0,
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=False,
|
||||
)
|
||||
|
||||
|
||||
def simulate_worker_end_to_end(
|
||||
tts: TTS,
|
||||
specs: Sequence[SchedulerRequestSpec],
|
||||
max_steps: int,
|
||||
rounds: int,
|
||||
grad_mode: str = "default",
|
||||
) -> Dict[str, Any]:
|
||||
device = tts.configs.device
|
||||
recorder = GlobalPeakRecorder(device)
|
||||
recorder.record("after_model_load")
|
||||
|
||||
state_map: Dict[str, T2SRequestState] = {}
|
||||
per_round: List[Dict[str, Any]] = []
|
||||
|
||||
for round_index in range(rounds):
|
||||
grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
|
||||
with grad_context():
|
||||
states = [prepare_request_state(tts, spec) for spec in specs]
|
||||
state_map = {state.request_id: state for state in states}
|
||||
recorder.record(
|
||||
"after_prepare_states",
|
||||
round_index=int(round_index),
|
||||
request_count=int(len(states)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
|
||||
pending = list(states)
|
||||
running_requests: List[T2SRunningRequest] = []
|
||||
round_events: List[Dict[str, Any]] = []
|
||||
current_tick = 0
|
||||
|
||||
while pending or running_requests:
|
||||
admitted = pending
|
||||
pending = []
|
||||
|
||||
if admitted:
|
||||
recorder.record(
|
||||
"before_prefill",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
admitted_count=int(len(admitted)),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
|
||||
recorder.record(
|
||||
"after_prefill",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
admitted_running_count=int(len(admitted_running)),
|
||||
admitted_finished_count=int(len(admitted_finished)),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
round_events.append(
|
||||
{
|
||||
"tick": int(current_tick),
|
||||
"event": "prefill",
|
||||
"admitted_count": int(len(admitted)),
|
||||
"admitted_running_count": int(len(admitted_running)),
|
||||
"admitted_finished_count": int(len(admitted_finished)),
|
||||
}
|
||||
)
|
||||
for item in admitted_finished:
|
||||
recorder.record(
|
||||
"before_synth_prefill_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
semantic_len=int(item.semantic_tokens.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
|
||||
recorder.record(
|
||||
"after_synth_prefill_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
sample_rate=int(sample_rate),
|
||||
audio_samples=int(audio_data.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
running_requests.extend(admitted_running)
|
||||
recorder.record(
|
||||
"after_extend_running",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
|
||||
if running_requests:
|
||||
recorder.record(
|
||||
"before_decode",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
running_requests, step_finished = run_decode_step_for_running(
|
||||
tts.t2s_model.model,
|
||||
running_requests,
|
||||
max_steps=max_steps,
|
||||
)
|
||||
recorder.record(
|
||||
"after_decode",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_count=int(len(step_finished)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
round_events.append(
|
||||
{
|
||||
"tick": int(current_tick),
|
||||
"event": "decode",
|
||||
"running_count_after_decode": int(len(running_requests)),
|
||||
"finished_count": int(len(step_finished)),
|
||||
}
|
||||
)
|
||||
for item in step_finished:
|
||||
recorder.record(
|
||||
"before_synth_decode_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
semantic_len=int(item.semantic_tokens.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
|
||||
recorder.record(
|
||||
"after_synth_decode_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
sample_rate=int(sample_rate),
|
||||
audio_samples=int(audio_data.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
current_tick += 1
|
||||
|
||||
recorder.record(
|
||||
"after_round_complete",
|
||||
round_index=int(round_index),
|
||||
running_count=0,
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
per_round.append(
|
||||
{
|
||||
"round_index": int(round_index),
|
||||
"events": round_events,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"grad_mode": grad_mode,
|
||||
"rounds": per_round,
|
||||
"timeline": recorder.summary(),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tts = load_pipeline(args.config)
|
||||
model = tts.t2s_model.model
|
||||
device = tts.configs.device
|
||||
use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
|
||||
set_seed(args.seed, use_cuda)
|
||||
|
||||
specs = load_request_specs(args)
|
||||
if args.early_stop_num == -1:
|
||||
for spec in specs:
|
||||
spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)
|
||||
|
||||
if args.warmup and specs:
|
||||
warmup_spec = specs[:1]
|
||||
_ = [prepare_request_state(tts, spec) for spec in warmup_spec]
|
||||
gc.collect()
|
||||
if use_cuda:
|
||||
torch.cuda.empty_cache()
|
||||
_sync_device(device)
|
||||
|
||||
states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
|
||||
request_state_summary = summarise_state_tensors(states)
|
||||
|
||||
active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
|
||||
prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)
|
||||
|
||||
prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
|
||||
running_requests, finished_items = prefill_result
|
||||
running_requests_summary = summarise_running_requests(running_requests)
|
||||
finished_after_prefill_summary = [
|
||||
{
|
||||
"request_id": item.request_id,
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
}
|
||||
for item in finished_items
|
||||
]
|
||||
|
||||
if not running_requests:
|
||||
raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}")
|
||||
|
||||
decode_batch_result, decode_batch_mem = stage_run(
|
||||
device,
|
||||
lambda: _build_decode_batch_from_running(model, running_requests),
|
||||
)
|
||||
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
|
||||
decode_batch_tensor_summary = summarise_decode_batch(
|
||||
xy_pos,
|
||||
batched_k_cache,
|
||||
batched_v_cache,
|
||||
batched_decode_attn_mask,
|
||||
running_requests,
|
||||
)
|
||||
|
||||
decode_out_result, decode_step_mem = stage_run(
|
||||
device,
|
||||
lambda: model.t2s_transformer.decode_next_token(
|
||||
xy_pos,
|
||||
batched_k_cache,
|
||||
batched_v_cache,
|
||||
batched_decode_attn_mask,
|
||||
),
|
||||
)
|
||||
xy_dec, next_k_cache, next_v_cache = decode_out_result
|
||||
decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
|
||||
del active_batch
|
||||
del running_requests
|
||||
del finished_items
|
||||
del xy_pos
|
||||
del batched_k_cache
|
||||
del batched_v_cache
|
||||
del batched_decode_attn_mask
|
||||
del xy_dec
|
||||
del next_k_cache
|
||||
del next_v_cache
|
||||
gc.collect()
|
||||
if use_cuda:
|
||||
_sync_device(device)
|
||||
torch.cuda.empty_cache()
|
||||
end_to_end_worker = simulate_worker_end_to_end(
|
||||
tts=tts,
|
||||
specs=specs,
|
||||
max_steps=args.max_steps,
|
||||
rounds=args.worker_rounds,
|
||||
grad_mode=args.worker_grad_mode,
|
||||
)
|
||||
live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
|
||||
worker_inference_mode = None
|
||||
if args.compare_worker_grad_modes:
|
||||
gc.collect()
|
||||
if use_cuda:
|
||||
_sync_device(device)
|
||||
torch.cuda.empty_cache()
|
||||
worker_inference_mode = simulate_worker_end_to_end(
|
||||
tts=tts,
|
||||
specs=specs,
|
||||
max_steps=args.max_steps,
|
||||
rounds=args.worker_rounds,
|
||||
grad_mode="inference_mode",
|
||||
)
|
||||
|
||||
summary = {
|
||||
"meta": {
|
||||
"scenario": args.scenario if args.request_manifest is None else "manifest",
|
||||
"seed": int(args.seed),
|
||||
"device": str(device),
|
||||
"dtype": str(next(model.parameters()).dtype),
|
||||
"request_count": int(len(specs)),
|
||||
"num_layers": int(model.num_layers),
|
||||
"num_heads": int(model.num_head),
|
||||
"model_dim": int(model.model_dim),
|
||||
"model_weights_mb": bytes_to_mb(model_nbytes(model)),
|
||||
},
|
||||
"loaded_module_weights": build_module_weight_summary(tts),
|
||||
"requests": [
|
||||
{
|
||||
"request_id": spec.request_id,
|
||||
"ref_audio_path": str(spec.ref_audio_path),
|
||||
"prompt_text": spec.prompt_text,
|
||||
"text": spec.text,
|
||||
}
|
||||
for spec in specs
|
||||
],
|
||||
"prepare_stage": {
|
||||
"memory": prepare_mem,
|
||||
"request_state": request_state_summary,
|
||||
},
|
||||
"prefill_batch": {
|
||||
"memory": prefill_batch_mem,
|
||||
"tensor_bytes": prefill_batch_tensor_summary,
|
||||
},
|
||||
"prefill_step": {
|
||||
"memory": prefill_step_mem,
|
||||
"running_requests": running_requests_summary,
|
||||
"finished_after_prefill": finished_after_prefill_summary,
|
||||
},
|
||||
"decode_batch": {
|
||||
"memory": decode_batch_mem,
|
||||
"tensor_bytes": decode_batch_tensor_summary,
|
||||
},
|
||||
"decode_outputs": {
|
||||
"memory": decode_step_mem,
|
||||
"tensor_bytes": decode_output_tensor_summary,
|
||||
},
|
||||
"end_to_end_worker": end_to_end_worker,
|
||||
"live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
|
||||
"end_to_end_worker_inference_mode": worker_inference_mode,
|
||||
}
|
||||
summary["top_rankings"] = top_rankings(summary)
|
||||
|
||||
summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
|
||||
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
|
||||
print("[top_rankings]")
|
||||
for item in summary["top_rankings"]:
|
||||
print(f"- {item['name']}: {item['mb']:.3f} MB")
|
||||
print("[worker_peak]")
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
|
||||
"peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
|
||||
"peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
if worker_inference_mode is not None:
|
||||
print("[worker_peak_inference_mode]")
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"peak_label": worker_inference_mode["timeline"]["peak_label"],
|
||||
"peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
|
||||
"peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
print(f"[summary] {summary_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user