mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-11 20:38:12 +08:00
251 lines
10 KiB
Python
251 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
import wave
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
|
|
# Parent of the directory that contains this script (parents[1] of the resolved
# file path). It is the base for the default --text-file, --wav-dir and
# --output-dir locations built in parse_args().
ROOT_DIR = Path(__file__).resolve().parents[1]
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    ``--concurrency`` is the only required option. Every other option has a
    default; the three path options default to locations under ``ROOT_DIR``.

    Returns:
        The namespace produced by ``argparse`` for the current ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory."
    )
    # (flag, add_argument keyword arguments), kept in the order in which the
    # options should be registered and therefore listed by --help.
    option_table = (
        ("--base-url", {"type": str, "default": "http://127.0.0.1:9880"}),
        ("--endpoint", {"type": str, "default": "/tts_scheduler_submit"}),
        ("--concurrency", {"type": int, "required": True}),
        ("--timeout-sec", {"type": float, "default": 120.0}),
        ("--server-pid", {"type": int, "default": None}),
        ("--poll-interval-sec", {"type": float, "default": 0.1}),
        ("--text-lang", {"type": str, "default": "zh"}),
        ("--prompt-lang", {"type": str, "default": "zh"}),
        ("--media-type", {"type": str, "default": "wav"}),
        ("--top-k", {"type": int, "default": 15}),
        ("--top-p", {"type": float, "default": 1.0}),
        ("--temperature", {"type": float, "default": 1.0}),
        ("--repetition-penalty", {"type": float, "default": 1.35}),
        ("--sample-steps", {"type": int, "default": 32}),
        ("--text-file", {"type": Path, "default": ROOT_DIR / "test_cn.txt"}),
        ("--wav-dir", {"type": Path, "default": ROOT_DIR / "testwav"}),
        ("--output-dir", {"type": Path, "default": ROOT_DIR / "TEMP/api_v3_bench"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
|
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Build one request payload per concurrent slot.

    Reference clips are the ``*.wav`` files in ``args.wav_dir`` (sorted by
    path) whose duration is between 3.0 and 10.0 seconds inclusive. Files that
    the ``wave`` module cannot read, or that declare a non-positive frame
    rate, are skipped instead of aborting the run. Clips and the non-empty
    lines of ``args.text_file`` are assigned to requests round-robin.

    Args:
        args: Parsed options. Reads ``wav_dir``, ``text_file``,
            ``concurrency`` and the synthesis settings copied into each
            payload.

    Returns:
        A list of ``args.concurrency`` payload dictionaries.

    Raises:
        FileNotFoundError: No usable clip was found, or a selected clip has
            no sibling ``.lab`` file holding its reference text.
        ValueError: The text file contains no non-empty line.
    """
    wav_paths: List[Path] = []
    for wav_path in sorted(args.wav_dir.glob("*.wav")):
        try:
            with wave.open(str(wav_path), "rb") as handle:
                frame_count = handle.getnframes()
                frame_rate = handle.getframerate()
        except (wave.Error, EOFError):
            # Corrupt, truncated or non-PCM file: not a usable reference clip.
            continue
        if frame_rate <= 0:
            # A zero frame rate would make the duration undefined.
            continue
        if 3.0 <= frame_count / float(frame_rate) <= 10.0:
            wav_paths.append(wav_path)
    if not wav_paths:
        raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")

    text_lines = [
        line.strip()
        for line in args.text_file.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if not text_lines:
        raise ValueError(f"没有找到有效文本行: {args.text_file}")

    # Clips are reused when concurrency exceeds the number of clips, so each
    # reference text is read from disk once and then served from this cache.
    prompt_texts: Dict[Path, str] = {}
    requests: List[Dict[str, Any]] = []
    for index in range(args.concurrency):
        wav_path = wav_paths[index % len(wav_paths)]
        if wav_path not in prompt_texts:
            lab_path = wav_path.with_suffix(".lab")
            if not lab_path.exists():
                raise FileNotFoundError(f"缺少参考文本: {lab_path}")
            prompt_texts[wav_path] = lab_path.read_text(encoding="utf-8").strip()
        requests.append(
            {
                "request_id": f"bench_{args.concurrency:03d}_{index:03d}",
                "text": text_lines[index % len(text_lines)],
                "text_lang": args.text_lang,
                "ref_audio_path": str(wav_path),
                "prompt_lang": args.prompt_lang,
                "prompt_text": prompt_texts[wav_path],
                "top_k": int(args.top_k),
                "top_p": float(args.top_p),
                "temperature": float(args.temperature),
                "repetition_penalty": float(args.repetition_penalty),
                "sample_steps": int(args.sample_steps),
                "media_type": args.media_type,
                "timeout_sec": float(args.timeout_sec),
            }
        )
    return requests
|
|
|
|
|
|
class GpuMemoryPoller:
    """Sample GPU memory usage from ``nvidia-smi`` on a background thread.

    Each sample is a dict with a wall-clock timestamp (``ts``) and the memory
    figure reported by ``nvidia-smi`` (``used_mb``), or ``None`` when the
    query failed. Samples accumulate in ``self.samples`` between ``start()``
    and ``stop()``.
    """

    def __init__(self, server_pid: Optional[int], interval_sec: float):
        """Store the settings; no thread is started until ``start()``.

        Args:
            server_pid: Only rows for this PID are counted. ``None`` counts
                every compute process listed by ``nvidia-smi``.
            interval_sec: Pause between two consecutive queries.
        """
        self.server_pid = server_pid
        self.interval_sec = interval_sec
        self.samples: List[Dict[str, Any]] = []
        self.thread: Optional[threading.Thread] = None
        self._stop = threading.Event()

    def _query_memory_mb(self) -> Optional[int]:
        """Run one ``nvidia-smi`` query and sum the matching memory values.

        Returns:
            The summed ``used_gpu_memory`` of the matching rows (0 when no
            row matches), or ``None`` if running ``nvidia-smi`` failed.
        """
        command = [
            "nvidia-smi",
            "--query-compute-apps=pid,used_gpu_memory",
            "--format=csv,noheader,nounits",
        ]
        try:
            completed = subprocess.run(command, check=True, capture_output=True, text=True)
        except Exception:
            # Best effort: a missing binary or a failing call yields "no data".
            return None

        used_total = 0
        for raw_row in completed.stdout.splitlines():
            fields = [field.strip() for field in raw_row.strip().split(",")]
            # Blank rows split into a single empty field and are rejected here.
            if len(fields) != 2:
                continue
            try:
                row_pid = int(fields[0])
                row_used = int(fields[1])
            except ValueError:
                # Non-numeric cells cannot be summed; ignore the row.
                continue
            if self.server_pid is not None and row_pid != self.server_pid:
                continue
            used_total += row_used
        return used_total

    def _run(self) -> None:
        """Thread body: query, record, then wait, until ``stop()`` is called."""
        while True:
            if self._stop.is_set():
                break
            reading = self._query_memory_mb()
            self.samples.append({"ts": time.time(), "used_mb": reading})
            # Event.wait doubles as an interruptible sleep.
            self._stop.wait(self.interval_sec)

    def start(self) -> None:
        """Create and start the daemon sampling thread."""
        worker = threading.Thread(target=self._run, daemon=True)
        self.thread = worker
        worker.start()

    def stop(self) -> None:
        """Signal the sampling loop to end and wait up to 2 s for the thread."""
        self._stop.set()
        worker = self.thread
        if worker is None:
            return
        worker.join(timeout=2.0)

    def summary(self) -> Dict[str, Any]:
        """Summarise the collected samples.

        Samples whose ``used_mb`` is ``None`` are ignored for the start, peak
        and end figures but still count towards ``sample_count``. The peak is
        the first sample holding the maximum value.

        Returns:
            A dict with the configured PID, the sample count, start/peak/end
            memory figures, the peak's delta to the first valid sample, the
            peak timestamp, and the raw sample list.
        """
        usable = [sample for sample in self.samples if sample["used_mb"] is not None]

        start_mb = peak_mb = delta_mb = end_mb = peak_ts = None
        if usable:
            earliest = usable[0]
            latest = usable[-1]
            highest = max(usable, key=lambda sample: sample["used_mb"])
            start_mb = int(earliest["used_mb"])
            peak_mb = int(highest["used_mb"])
            delta_mb = int(highest["used_mb"] - earliest["used_mb"])
            end_mb = int(latest["used_mb"])
            peak_ts = float(highest["ts"])

        return {
            "server_pid": self.server_pid,
            "sample_count": int(len(self.samples)),
            "start_used_mb": start_mb,
            "peak_used_mb": peak_mb,
            "peak_delta_mb": delta_mb,
            "end_used_mb": end_mb,
            "peak_ts": peak_ts,
            "samples": self.samples,
        }
|
|
|
|
|
|
async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST one payload as JSON and describe the outcome as a plain dict.

    Args:
        client: HTTP client used to send the request.
        url: Full URL of the endpoint.
        payload: Request body; its ``request_id`` is echoed in the result.

    Returns:
        On a received response: request id, status code, elapsed
        milliseconds, content type, body size, and the response headers whose
        name starts with ``x-`` (case-insensitive). Non-200 responses also
        carry ``error_body`` (parsed JSON, or the raw text if parsing fails).
        If anything raises, a record with ``status_code`` -1 and the
        exception's ``repr`` is returned instead.
    """
    t0 = time.perf_counter()
    try:
        reply = await client.post(url, json=payload)
        took_ms = (time.perf_counter() - t0) * 1000.0

        record: Dict[str, Any] = {
            "request_id": payload["request_id"],
            "status_code": int(reply.status_code),
            "elapsed_ms": float(took_ms),
            "content_type": reply.headers.get("content-type"),
            "audio_bytes": int(len(reply.content)),
        }
        # Keep only the custom "x-" headers of the response.
        custom_headers = {}
        for name, value in reply.headers.items():
            if name.lower().startswith("x-"):
                custom_headers[name] = value
        record["headers"] = custom_headers

        if reply.status_code == 200:
            return record

        # Failed request: keep the body for diagnosis, as JSON when possible.
        try:
            record["error_body"] = reply.json()
        except Exception:
            record["error_body"] = reply.text
        return record
    except Exception as exc:
        # Any failure becomes a result row so one request cannot abort the batch.
        failed_after_ms = (time.perf_counter() - t0) * 1000.0
        return {
            "request_id": payload["request_id"],
            "status_code": -1,
            "elapsed_ms": float(failed_after_ms),
            "exception": repr(exc),
        }
|
|
|
|
|
|
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
    """Fire all requests concurrently while sampling GPU memory.

    All payloads from ``load_requests`` are submitted at once through a
    single HTTP client whose connection limits equal ``args.concurrency``.
    A ``GpuMemoryPoller`` runs for the duration of the batch.

    ``wall_ms`` covers only the concurrent submission itself. It is taken
    before the poller is stopped and before the HTTP client is closed, so the
    poller's thread join and connection teardown do not inflate the figure.

    Args:
        args: Parsed command-line options.

    Returns:
        A dict with request/success/failure counts, ``wall_ms``, average and
        maximum of the ``x-request-total-ms`` and ``x-worker-total-ms``
        response headers over successful requests (``None`` when no response
        carried the header), the GPU memory summary and the per-request
        results.
    """
    payloads = load_requests(args)
    url = args.base_url.rstrip("/") + args.endpoint
    poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)

    limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
    # The read timeout leaves a 10 s margin over the timeout sent in each payload.
    timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0)

    poller.start()
    try:
        async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
            started = time.perf_counter()
            results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads])
            # Stop the clock here: client shutdown and poller.stop() are not
            # part of the measured request time.
            wall_ms = (time.perf_counter() - started) * 1000.0
    finally:
        poller.stop()

    ok_results = [item for item in results if item["status_code"] == 200]
    failed_results = [item for item in results if item["status_code"] != 200]

    def _header_values(name: str) -> List[float]:
        """Collect one timing header, as floats, from the successful results."""
        return [
            float(item["headers"][name])
            for item in ok_results
            if name in item.get("headers", {})
        ]

    def _average(values: List[float]) -> Optional[float]:
        """Return the mean of the values, or None for an empty list."""
        return float(sum(values) / len(values)) if values else None

    def _maximum(values: List[float]) -> Optional[float]:
        """Return the largest value, or None for an empty list."""
        return float(max(values)) if values else None

    request_total_ms = _header_values("x-request-total-ms")
    worker_total_ms = _header_values("x-worker-total-ms")

    return {
        "concurrency": int(args.concurrency),
        "server_pid": args.server_pid,
        "request_count": int(len(payloads)),
        "wall_ms": float(wall_ms),
        "success_count": int(len(ok_results)),
        "failure_count": int(len(failed_results)),
        "request_total_ms_avg": _average(request_total_ms),
        "request_total_ms_max": _maximum(request_total_ms),
        "worker_total_ms_avg": _average(worker_total_ms),
        "worker_total_ms_max": _maximum(worker_total_ms),
        "gpu_memory": poller.summary(),
        "results": results,
    }
|
|
|
|
|
|
def main() -> None:
    """Run one benchmark pass, save the full summary and print a digest.

    The summary is written as ``summary.json`` inside a
    ``concurrency_NN`` sub-directory of the configured output directory; a
    short JSON digest with the headline numbers goes to standard output.
    """
    args = parse_args()

    # The result directory is created before the run starts.
    run_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
    run_dir.mkdir(parents=True, exist_ok=True)

    summary = asyncio.run(run_benchmark(args))

    summary_path = run_dir / "summary.json"
    serialized = json.dumps(summary, ensure_ascii=False, indent=2)
    summary_path.write_text(serialized, encoding="utf-8")

    # Headline figures only; the per-request details stay in the file.
    digest = {
        key: summary[key]
        for key in ("concurrency", "success_count", "failure_count", "wall_ms")
    }
    digest["gpu_peak_used_mb"] = summary["gpu_memory"]["peak_used_mb"]
    for key in ("request_total_ms_avg", "request_total_ms_max"):
        digest[key] = summary[key]
    digest["summary_path"] = str(summary_path)
    print(json.dumps(digest, ensure_ascii=False, indent=2))
|
|
|
|
|
|
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|