GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py
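
"""Scheduler-backed request flows for the unified TTS engine API.

EngineApiSchedulerFlow drives two paths through the scheduler: a batched
debug run (run_scheduler_debug) that caps decode steps and returns timing
diagnostics, and a single-request submit (run_scheduler_submit) that
returns packed audio together with profiling headers.

Minimal usage sketch, assuming an ``api`` object that exposes the private
helpers referenced below (``_normalize_engine_request`` and friends):

    flow = EngineApiSchedulerFlow(api)
    execution = await flow.run_scheduler_submit(payload)  # payload: dict accepted by api._normalize_engine_request
    audio_bytes, media_type = execution.audio_bytes, execution.media_type
"""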
from __future__ import annotations
import asyncio
import time
import uuid
from io import BytesIO
from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
    SchedulerRequestSpec,
    T2SFinishedItem,
    T2SRequestState,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
    EngineStatus,
    NormalizedEngineRequest,
    SchedulerDebugExecution,
    SchedulerSubmitExecution,
)


class EngineApiSchedulerFlow:
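    """Orchestrates scheduler-backed TTS requests on behalf of the engine API."""
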
def __init__(self, api: Any) -> None:
self.api = api

    def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
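        """Normalize raw request payloads and convert each one to a SchedulerRequestSpec."""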
specs: List[SchedulerRequestSpec] = []
for index, payload in enumerate(request_items):
normalized = self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
error_prefix=f"request[{index}] 参数非法: ",
)
specs.append(normalized.to_scheduler_spec())
return specs

    def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
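        """Build a scheduler spec from a dict payload or an already-normalized request."""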
normalized = self.api._normalize_engine_request(
payload,
request_id=(
payload.request_id
if isinstance(payload, NormalizedEngineRequest)
else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}")
),
)
return normalized.to_scheduler_spec()

    @staticmethod
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
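        """Summarize prepared request states for the debug response payload."""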
return [
{
"request_id": state.request_id,
"ready_step": int(state.ready_step),
"ref_audio_path": str(state.ref_audio_path),
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"all_phone_len": int(state.all_phones.shape[0]),
"bert_len": int(state.all_bert_features.shape[-1]),
"norm_text": state.norm_text,
}
for state in states
]

    @staticmethod
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
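        """Summarize finished decode items for the debug response payload."""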
return [
{
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
}
for item in items
]

    async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
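        """Run a batch of requests through the scheduler with decode steps capped at max_steps.

        Returns a SchedulerDebugExecution whose payload bundles per-request
        profiles, finish-reason counts, and batch-level prepare/decode timings.
        """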
request_start = time.perf_counter()
set_scheduler_seed(seed)
normalized_requests: List[NormalizedEngineRequest] = []
for index, payload in enumerate(request_items):
normalized_requests.append(
self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
error_prefix=f"request[{index}] 参数非法: ",
)
)
specs = [normalized.to_scheduler_spec() for normalized in normalized_requests]
request_ids = [normalized.request_id for normalized in normalized_requests]
for normalized, spec in zip(normalized_requests, specs):
self.api._register_request_state(
request_id=normalized.request_id,
api_mode="scheduler_debug",
backend="scheduler_debug",
media_type=normalized.media_type,
response_streaming=False,
meta=self.api._build_request_meta(normalized.to_payload()),
)
self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None)
prepare_started_at = time.perf_counter()
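        # Temporarily cap the worker's decode steps for this debug run; the
        # original limits are restored in the finally block below.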
original_worker_max_steps = int(self.api.scheduler_worker.max_steps)
original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps)
try:
self.api.scheduler_worker.max_steps = int(max_steps)
self.api.scheduler_worker.decode_executor.max_steps = int(max_steps)
prepared_payloads = await asyncio.gather(
*[
self.api._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=time.perf_counter(),
engine_request_id=normalized.request_id,
)
for normalized, spec in zip(normalized_requests, specs)
]
)
except Exception as exc:
for request_id in request_ids:
self.api._fail_request_state(request_id, str(exc))
raise
finally:
self.api.scheduler_worker.max_steps = int(original_worker_max_steps)
self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps)
prepare_finished_at = time.perf_counter()
prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
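        # _prepare_state_via_engine_gpu_queue returns (state, exec_started_at, exec_finished_at).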
states = [payload[0] for payload in prepared_payloads]
for state in states:
self.api._update_request_state(
state.request_id,
EngineStatus.READY_FOR_PREFILL,
{
"prepare_profile": dict(state.prepare_profile),
"norm_text": state.norm_text,
"norm_prompt_text": state.norm_prompt_text,
},
)
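        # Decode phase: enqueue every prepared state and collect one completion future per request.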
decode_started_at = time.perf_counter()
try:
loop = asyncio.get_running_loop()
done_futures: List[asyncio.Future] = []
for normalized, state in zip(normalized_requests, states):
done_future = loop.create_future()
done_futures.append(done_future)
await self.api._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
done_loop=loop,
done_future=done_future,
engine_request_id=normalized.request_id,
timeout_sec=normalized.timeout_sec,
)
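            # Wait for the whole batch, using the largest per-request timeout (default 60s).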
            timeout_candidates = [
                float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in (None, "")
            ]
timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0
jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec)))
except Exception as exc:
for request_id in request_ids:
self.api._fail_request_state(request_id, str(exc))
raise
decode_finished_at = time.perf_counter()
decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
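        # Aggregate per-request profiles and finish statistics for the response payload.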
request_profiles: List[Dict[str, Any]] = []
finished: List[Dict[str, Any]] = []
finish_reason_counts: Dict[str, int] = {}
total_semantic_len = 0
for state, job in zip(states, jobs):
if job.error is not None:
self.api._fail_request_state(state.request_id, str(job.error))
raise RuntimeError(str(job.error))
if job.result is None:
self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
raise RuntimeError(f"{state.request_id} finished without result")
job_result = dict(job.result)
request_profile = {
**job_result,
"backend": "scheduler_debug",
"backend_mode": "scheduler_debug",
"batch_request_count": int(len(states)),
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
"batch_decode_wall_ms": float(decode_batch_wall_ms),
"batch_request_total_ms": float(request_total_ms),
"prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_profile": dict(state.prepare_profile),
"norm_text": state.norm_text,
"norm_prompt_text": state.norm_prompt_text,
}
request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)})
self.api._merge_request_state_profile(state.request_id, request_profile)
semantic_len = int(job_result.get("semantic_len", 0))
finish_reason = str(job_result.get("finish_reason", "unknown"))
finished.append(
{
"request_id": state.request_id,
"semantic_len": semantic_len,
"finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))),
"finish_reason": finish_reason,
}
)
finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1
total_semantic_len += semantic_len
return SchedulerDebugExecution(
payload={
"message": "success",
"request_count": len(states),
"max_steps": int(max_steps),
"batch_profile": {
"request_count": int(len(states)),
"max_steps": int(max_steps),
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
"decode_batch_wall_ms": float(decode_batch_wall_ms),
"request_total_ms": float(request_total_ms),
"total_semantic_len": int(total_semantic_len),
"finish_reason_counts": finish_reason_counts,
},
"requests": self._summarize_scheduler_states(states),
"finished": finished,
"request_profiles": request_profiles,
"request_traces": self.api._collect_request_summaries(request_ids),
}
)

    async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
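        """Run a single request end-to-end through the scheduler and return packed audio.

        Tracks request-state transitions along the way and builds a timing
        profile that is surfaced through the response headers.
        """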
request_start = time.perf_counter()
prepare_start = request_start
normalized = self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
)
spec = self._build_scheduler_submit_spec(normalized)
deadline_ts = None
timeout_sec = normalized.timeout_sec
if timeout_sec is not None:
try:
deadline_ts = request_start + float(timeout_sec)
except Exception:
deadline_ts = None
self.api._register_request_state(
request_id=spec.request_id,
api_mode="scheduler_submit",
backend="scheduler_v1",
media_type=normalized.media_type,
response_streaming=False,
deadline_ts=deadline_ts,
meta=self.api._build_request_meta(normalized.to_payload()),
)
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"})
spec_ready_at = time.perf_counter()
prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms})
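        # Prepare phase: build the request state through the engine's GPU queue.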
try:
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=spec_ready_at,
engine_request_id=spec.request_id,
)
except Exception as exc:
self.api._fail_request_state(spec.request_id, str(exc))
raise
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0)
prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)
prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
prepare_profile = dict(state.prepare_profile)
prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms)
self.api._update_request_state(
spec.request_id,
EngineStatus.READY_FOR_PREFILL,
{
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"prepare_profile": prepare_profile,
},
)
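        # Dispatch phase: hand the prepared state to the scheduler; completion
        # is signalled through done_future on the current event loop.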
api_after_prepare_start = time.perf_counter()
loop = asyncio.get_running_loop()
done_future = loop.create_future()
await self.api._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=loop,
done_future=done_future,
engine_request_id=spec.request_id,
timeout_sec=normalized.timeout_sec,
)
api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0)
try:
            job = await asyncio.wait_for(
                done_future,
                timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0),
            )
except Exception as exc:
self.api._fail_request_state(spec.request_id, str(exc))
raise
wait_return_at = time.perf_counter()
        if job.error is not None:
            self.api._fail_request_state(spec.request_id, str(job.error))
            raise RuntimeError(str(job.error))
if job.audio_data is None or job.sample_rate is None or job.result is None:
self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result")
raise RuntimeError(f"{job.request_id} finished without audio result")
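        # Pack the raw audio into the requested container format.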
pack_start = time.perf_counter()
audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
pack_end = time.perf_counter()
pack_ms = (pack_end - pack_start) * 1000.0
api_wait_result_ms = 0.0
if job.result_ready_time is not None:
api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
response_ready_at = time.perf_counter()
response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
submit_profile = self.api._build_scheduler_submit_profile(
backend="scheduler_v1",
request_start=request_start,
response_ready_at=response_ready_at,
audio_bytes=len(audio_data),
sample_rate=int(job.sample_rate),
prepare_spec_build_ms=prepare_spec_build_ms,
prepare_wall_ms=prepare_wall_ms,
prepare_executor_queue_ms=prepare_executor_queue_ms,
prepare_executor_run_ms=prepare_executor_run_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
prepare_profile_wall_ms=prepare_profile_wall_ms,
prepare_other_ms=prepare_other_ms,
engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)),
api_after_prepare_ms=api_after_prepare_ms,
api_wait_result_ms=api_wait_result_ms,
pack_ms=pack_ms,
response_overhead_ms=response_overhead_ms,
worker_profile=dict(job.result or {}),
)
headers = self.api._build_scheduler_submit_headers(
request_id=job.request_id,
media_type=job.media_type,
sample_rate=int(job.sample_rate),
profile=submit_profile,
)
self.api._merge_request_state_profile(
spec.request_id,
dict(submit_profile, response_headers_emitted=True),
)
return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers)