diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index 9b56199a..24e1c98b 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -1,2756 +1,15 @@ from __future__ import annotations -import asyncio import os -import signal -import subprocess -import sys -import threading -import time -import uuid -import wave -from collections import deque -from dataclasses import dataclass, field -from io import BytesIO -from pathlib import Path -from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Sequence, Tuple, Union - -import numpy as np -import soundfile as sf -import torch +from typing import Sequence from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( - SchedulerRequestSpec, - T2SActiveBatch, - T2SFinishedItem, - T2SRequestState, - decode_one_step, - merge_active_batches, - run_prefill_active_batch, - run_scheduler_continuous, -) +from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks +from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates -@dataclass -class RuntimeControlCallbacks: - restart: Callable[[], None] | None = None - exit: Callable[[], None] | None = None - - -@dataclass -class DefaultReferenceState: - ref_audio_path: str | None = None - updated_at: float = 0.0 - - -class ReferenceRegistry: - def __init__(self) -> None: - self._lock = threading.Lock() - self._state = DefaultReferenceState() - - def set_default(self, ref_audio_path: str) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) - return self._state - - def clear(self) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState() - return self._state - - def get_default(self) -> DefaultReferenceState: - with self._lock: - return DefaultReferenceState( - ref_audio_path=self._state.ref_audio_path, - updated_at=self._state.updated_at, - ) - - -@dataclass -class ModelRegistryState: - t2s_weights_path: str - vits_weights_path: str - generation: int = 0 - t2s_generation: int = 0 - vits_generation: int = 0 - updated_at: float = field(default_factory=time.time) - - -class ModelRegistry: - def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: - self._lock = threading.Lock() - self._state = ModelRegistryState( - t2s_weights_path=str(t2s_weights_path), - vits_weights_path=str(vits_weights_path), - ) - - def snapshot(self) -> ModelRegistryState: - with self._lock: - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.t2s_weights_path = str(weights_path) - self._state.generation += 1 - self._state.t2s_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.vits_weights_path = str(weights_path) - self._state.generation += 1 - self._state.vits_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - -@dataclass -class DirectTTSExecution: - media_type: str - streaming: bool - audio_generator: Optional[Generator[bytes, None, None]] = None - audio_bytes: Optional[bytes] = None - request_id: Optional[str] = None - - -@dataclass -class NormalizedEngineRequest: - request_id: str - text: str - text_lang: str - ref_audio_path: str - prompt_lang: str - prompt_text: str = "" - aux_ref_audio_paths: List[str] | None = None - top_k: int = 15 - top_p: float = 1.0 - temperature: float = 1.0 - repetition_penalty: float = 1.35 - early_stop_num: int = -1 - ready_step: int = 0 - text_split_method: str = "cut5" - batch_size: int = 1 - batch_threshold: float = 0.75 - split_bucket: bool = False - speed_factor: float = 1.0 - fragment_interval: float = 0.3 - seed: int = -1 - media_type: str = "wav" - streaming_mode: bool | int = False - return_fragment: bool = False - fixed_length_chunk: bool = False - response_streaming: bool = False - parallel_infer: bool = False - sample_steps: int = 32 - super_sampling: bool = False - overlap_length: int = 2 - min_chunk_length: int = 16 - timeout_sec: float | None = None - - def to_payload(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "text": self.text, - "text_lang": self.text_lang, - "ref_audio_path": self.ref_audio_path, - "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, - "prompt_text": self.prompt_text, - "prompt_lang": self.prompt_lang, - "top_k": self.top_k, - "top_p": self.top_p, - "temperature": self.temperature, - "text_split_method": self.text_split_method, - "batch_size": self.batch_size, - "batch_threshold": self.batch_threshold, - "speed_factor": self.speed_factor, - "split_bucket": self.split_bucket, - "fragment_interval": self.fragment_interval, - "seed": self.seed, - "media_type": self.media_type, - "streaming_mode": self.streaming_mode, - "return_fragment": self.return_fragment, - "fixed_length_chunk": self.fixed_length_chunk, - "response_streaming": self.response_streaming, - "parallel_infer": self.parallel_infer, - "repetition_penalty": self.repetition_penalty, - "sample_steps": self.sample_steps, - "super_sampling": self.super_sampling, - "overlap_length": self.overlap_length, - "min_chunk_length": self.min_chunk_length, - "early_stop_num": self.early_stop_num, - "ready_step": self.ready_step, - "timeout_sec": self.timeout_sec, - } - - def to_scheduler_spec(self) -> SchedulerRequestSpec: - return SchedulerRequestSpec( - request_id=self.request_id, - ref_audio_path=Path(self.ref_audio_path), - prompt_text=self.prompt_text, - prompt_lang=self.prompt_lang, - text=self.text, - text_lang=self.text_lang, - top_k=self.top_k, - top_p=self.top_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - early_stop_num=self.early_stop_num, - ready_step=self.ready_step, - ) - - -@dataclass -class SchedulerDebugExecution: - payload: Dict[str, Any] - - -@dataclass -class SchedulerSubmitExecution: - audio_bytes: bytes - media_type: str - headers: Dict[str, str] - - -@dataclass -class EnginePolicyConfig: - enabled: bool = True - poll_wait_ms: float = 5.0 - decode_backlog_soft_max: int = 0 - finalize_pending_soft_max: int = 0 - prepare_inflight_soft_max: int = 0 - active_decode_soft_max: int = 0 - ready_for_prefill_soft_max: int = 0 - active_request_soft_max: int = 0 - - def to_dict(self) -> Dict[str, Any]: - return { - "enabled": bool(self.enabled), - "poll_wait_ms": float(self.poll_wait_ms), - "decode_backlog_soft_max": int(self.decode_backlog_soft_max), - "finalize_pending_soft_max": int(self.finalize_pending_soft_max), - "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), - "active_decode_soft_max": int(self.active_decode_soft_max), - "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), - "active_request_soft_max": int(self.active_request_soft_max), - } - - -@dataclass -class EngineArbiterConfig: - poll_wait_ms: float = 5.0 - decode_burst: int = 4 - prepare_aging_ms: float = 10.0 - finalize_aging_ms: float = 10.0 - - def to_dict(self) -> Dict[str, Any]: - return { - "poll_wait_ms": float(self.poll_wait_ms), - "decode_burst": int(self.decode_burst), - "prepare_aging_ms": float(self.prepare_aging_ms), - "finalize_aging_ms": float(self.finalize_aging_ms), - } - - -class EngineStatus: - NEW = "NEW" - QUEUED = "QUEUED" - VALIDATED = "VALIDATED" - CPU_PREPARING = "CPU_PREPARING" - GPU_PREPARING = "GPU_PREPARING" - READY_FOR_PREFILL = "READY_FOR_PREFILL" - ACTIVE_DECODE = "ACTIVE_DECODE" - READY_FOR_FINALIZE = "READY_FOR_FINALIZE" - FINALIZING = "FINALIZING" - STREAMING = "STREAMING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - - -@dataclass -class EngineRequestState: - request_id: str - api_mode: str - backend: str - media_type: str - response_streaming: bool - submit_ts: float - deadline_ts: float | None = None - status: str = EngineStatus.NEW - updated_ts: float = 0.0 - error: str | None = None - finish_reason: str | None = None - meta: Dict[str, Any] = field(default_factory=dict) - profile: Dict[str, Any] = field(default_factory=dict) - lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "api_mode": self.api_mode, - "backend": self.backend, - "media_type": self.media_type, - "response_streaming": self.response_streaming, - "status": self.status, - "submit_ts": self.submit_ts, - "updated_ts": self.updated_ts, - "deadline_ts": self.deadline_ts, - "error": self.error, - "finish_reason": self.finish_reason, - "meta": dict(self.meta), - "profile": dict(self.profile), - "lifecycle_timestamps": dict(self.lifecycle_timestamps), - } - - -class EngineRequestRegistry: - def __init__(self, recent_limit: int) -> None: - self.lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_limit = max(1, int(recent_limit)) - - def register( - self, - *, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, - deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, - ) - with self.lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_limit: - self.recent_requests.pop() - - @staticmethod - def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: - if not extra: - return - payload = dict(extra) - backend = payload.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = payload.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = payload.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(payload) - - def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = str(status) - state.updated_ts = now - state.lifecycle_timestamps[str(status)] = now - self._apply_state_extra(state, extra) - - def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: - return - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - self._apply_state_extra(state, extra) - - def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - self._apply_state_extra(state, extra) - self._move_to_recent_locked(state) - - def fail(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) - - def snapshot(self) -> Dict[str, Any]: - with self.lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - recent_limit = self.recent_limit - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": recent_limit, - "active_requests": active, - "recent_requests": recent, - } - - def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.lock: - for state in self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - existing_ids = {item["request_id"] for item in results} - for state in self.recent_requests: - if state.request_id in requested and state.request_id not in existing_ids: - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results - - def has_active(self, request_id: str) -> bool: - with self.lock: - return request_id in self.active_requests - - -@dataclass -class SchedulerPendingJob: - request_id: str - state: T2SRequestState - done_event: threading.Event - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - enqueue_time: float - speed_factor: float - sample_steps: int - media_type: str - admission_wait_ms: float = 0.0 - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - prepare_wall_ms: float = 0.0 - prepare_profile_total_ms: float = 0.0 - first_schedule_time: float | None = None - prefill_ms: float = 0.0 - merge_ms: float = 0.0 - decode_ms: float = 0.0 - finalize_wait_ms: float = 0.0 - synth_ms: float = 0.0 - pack_ms: float = 0.0 - decode_steps: int = 0 - result_ready_time: float | None = None - result: dict | None = None - sample_rate: int | None = None - audio_data: np.ndarray | None = None - error: str | None = None - engine_request_id: str | None = None - - -class SchedulerJobRegistry: - def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: - self._lock = lock - self._job_map: Dict[str, SchedulerPendingJob] = {} - self._total_submitted = 0 - self._total_finished = 0 - - def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: - with self._lock: - if keep_job: - self._job_map[job.request_id] = job - self._total_submitted += 1 - - def get(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.get(request_id) - - def pop(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.pop(request_id, None) - - def remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - - def mark_finished(self) -> None: - with self._lock: - self._total_finished += 1 - - def mark_finished_and_remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - self._total_finished += 1 - - def is_empty(self) -> bool: - with self._lock: - return not self._job_map - - def submitted_count(self) -> int: - with self._lock: - return int(self._total_submitted) - - def finished_count(self) -> int: - with self._lock: - return int(self._total_finished) - - def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: - with self._lock: - request_ids = list(self._job_map.keys()) - return { - "job_count": int(len(request_ids)), - "request_ids": request_ids[: max(0, int(max_request_ids))], - "total_submitted": int(self._total_submitted), - "total_finished": int(self._total_finished), - } - - -class EngineTaskQueueOwner: - def __init__(self, completion_key: str = "total_completed") -> None: - self.condition = threading.Condition() - self.queue: Deque[Any] = deque() - self.total_submitted = 0 - self.total_completed = 0 - self.peak_waiting = 0 - self.completion_key = str(completion_key) - - def enqueue(self, item: Any) -> None: - with self.condition: - self.queue.append(item) - self.total_submitted += 1 - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def enqueue_many(self, items: Sequence[Any]) -> None: - if not items: - return - with self.condition: - for item in items: - self.queue.append(item) - self.total_submitted += len(items) - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def pop_left(self) -> Any | None: - with self.condition: - if not self.queue: - return None - return self.queue.popleft() - - def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: - if count <= 0: - return - with self.condition: - self.total_completed += int(count) - if notify: - self.condition.notify_all() - - def has_items(self) -> bool: - with self.condition: - return bool(self.queue) - - def waiting_count(self) -> int: - with self.condition: - return int(len(self.queue)) - - def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - with self.condition: - waiting_items = list(self.queue)[: max(0, int(max_request_ids))] - snapshot = { - "waiting_count": int(len(self.queue)), - "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], - "peak_waiting": int(self.peak_waiting), - "total_submitted": int(self.total_submitted), - self.completion_key: int(self.total_completed), - } - if extra: - snapshot.update(dict(extra)) - return snapshot - - def peek_oldest_age_ms(self, timestamp_attr: str) -> float: - with self.condition: - if not self.queue: - return 0.0 - enqueue_time = float(getattr(self.queue[0], timestamp_attr)) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def is_drained(self) -> bool: - with self.condition: - return not self.queue and self.total_submitted == self.total_completed - - def take_finalize_batch( - self, - *, - finalize_mode: str, - batch_max_items: int, - batch_wait_s: float, - use_vocoder: bool, - ) -> List[SchedulerFinalizeTask]: - with self.condition: - if not self.queue: - return [] - selected_tasks = [self.queue.popleft()] - if finalize_mode == "sync" or use_vocoder: - return selected_tasks - if batch_max_items <= 1: - return selected_tasks - first_task = selected_tasks[0] - oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) - if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: - self.queue.appendleft(first_task) - return [] - while len(selected_tasks) < batch_max_items: - if not self.queue: - break - matched_index = None - for index, task in enumerate(self.queue): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is None: - break - selected_tasks.append(self.queue[matched_index]) - del self.queue[matched_index] - return selected_tasks - - -class EnginePolicyArbiterController: - def __init__( - self, - *, - policy_config: EnginePolicyConfig, - arbiter_config: EngineArbiterConfig, - snapshot_request_registry: Callable[[], Dict[str, Any]], - get_worker_state: Callable[[], Dict[str, Any]], - snapshot_prepare_state: Callable[[], Dict[str, Any]], - snapshot_finalize_state: Callable[[], Dict[str, Any]], - snapshot_dispatch_state: Callable[[], Dict[str, Any]], - snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], - snapshot_job_registry: Callable[[], Dict[str, Any]], - peek_queue_age_ms: Callable[[str], float], - merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], - ) -> None: - self.policy_config = policy_config - self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) - self.arbiter_config = arbiter_config - self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) - self.condition = threading.Condition() - self.state = EngineArbiterState( - decode_budget_remaining=int(self.arbiter_config.decode_burst), - last_observed_at=time.perf_counter(), - ) - self.snapshot_request_registry = snapshot_request_registry - self.get_worker_state = get_worker_state - self.snapshot_prepare_state = snapshot_prepare_state - self.snapshot_finalize_state = snapshot_finalize_state - self.snapshot_dispatch_state = snapshot_dispatch_state - self.snapshot_decode_runtime_state = snapshot_decode_runtime_state - self.snapshot_job_registry = snapshot_job_registry - self.peek_queue_age_ms = peek_queue_age_ms - self.merge_request_state_profile = merge_request_state_profile - - def snapshot_state(self) -> Dict[str, Any]: - with self.condition: - return { - "config": self.arbiter_config.to_dict(), - "total_ticks": int(self.state.total_ticks), - "total_idle_ticks": int(self.state.total_idle_ticks), - "total_prepare_dispatches": int(self.state.total_prepare_dispatches), - "total_decode_dispatches": int(self.state.total_decode_dispatches), - "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), - "total_finalize_dispatches": int(self.state.total_finalize_dispatches), - "decode_budget_remaining": int(self.state.decode_budget_remaining), - "last_stage": str(self.state.last_stage), - "last_reason": str(self.state.last_reason), - "last_policy_allowed": bool(self.state.last_policy_allowed), - "last_observed_at": float(self.state.last_observed_at), - } - - def notify(self) -> None: - with self.condition: - self.condition.notify_all() - - def wait(self) -> None: - with self.condition: - self.condition.wait(timeout=self.arbiter_poll_s) - - def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - with self.condition: - self.state.total_ticks += 1 - if stage == "idle": - self.state.total_idle_ticks += 1 - elif stage == "prepare": - self.state.total_prepare_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "finalize": - self.state.total_finalize_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "decode_dispatch": - self.state.total_decode_dispatches += 1 - elif stage == "decode_runtime": - self.state.total_decode_runtime_ticks += 1 - self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) - self.state.last_stage = str(stage) - self.state.last_reason = str(reason) - self.state.last_policy_allowed = bool(policy_allowed) - self.state.last_observed_at = time.perf_counter() - - def build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - prepare_dispatcher_state = self.snapshot_prepare_state() - finalize_dispatcher_state = self.snapshot_finalize_state() - dispatcher_state = self.snapshot_dispatch_state() - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - - worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) - worker_decode_active_size = int(worker_state.get("running_requests", 0)) - worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) - worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) - worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) - engine_decode_runtime_state = self.snapshot_decode_runtime_state() - engine_job_registry = self.snapshot_job_registry() - decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) - decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) - return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": worker_pending_jobs, - "worker_decode_active_size": worker_decode_active_size, - "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), - "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), - "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, - "engine_decode_runtime_active_request_count": decode_runtime_active_size, - "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), - "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), - "worker_prepare_inflight": worker_prepare_inflight, - "worker_finalize_pending": worker_finalize_pending, - "worker_finalize_inflight": worker_finalize_inflight, - "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), - "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), - "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), - "decode_backlog": int( - decode_runtime_pending_jobs + decode_runtime_active_size - if bool(worker_state.get("engine_decode_control_enabled", False)) - else worker_pending_jobs + worker_decode_active_size - ), - } - - def build_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self.build_stage_counters(request_registry, worker_state) - config = self.policy_config.to_dict() - blocked_reasons: List[Dict[str, Any]] = [] - finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) - limit_checks = [ - ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), - ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), - ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), - ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), - ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), - ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), - ] - if bool(config["enabled"]): - for name, value, limit in limit_checks: - if limit > 0 and int(value) >= int(limit): - blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) - return { - "enabled": bool(config["enabled"]), - "allowed": (not bool(config["enabled"])) or not blocked_reasons, - "blocked_reasons": blocked_reasons, - "config": config, - "metrics": { - "active_request_count": int(counters["active_request_count"]), - "queued_request_count": int(counters["queued_request_count"]), - "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), - "active_decode_request_count": int(counters["active_decode_request_count"]), - "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), - "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), - "decode_backlog": int(counters["decode_backlog"]), - "prepare_inflight": int(counters["worker_prepare_inflight"]), - "finalize_pending": int(finalize_pending_total), - "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), - "finalize_inflight": int(counters["worker_finalize_inflight"]), - }, - "observed_at": time.perf_counter(), - } - - async def wait_for_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if not self.policy_config.enabled: - return 0.0, snapshot - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - while True: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if snapshot["allowed"]: - wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) - if request_id not in [None, ""]: - self.merge_request_state_profile( - str(request_id), - { - "engine_policy_wait_ms": float(wait_ms), - "engine_policy_snapshot": snapshot, - }, - ) - return wait_ms, snapshot - now = time.perf_counter() - if deadline is not None and now >= deadline: - blocked_summary = ", ".join( - f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) - ) - raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") - await asyncio.sleep(self.policy_poll_s) - - def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) - finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) - decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) - decode_runtime_state = self.snapshot_decode_runtime_state() - worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) - worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) - worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) - worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) - prepare_age_ms = float(self.peek_queue_age_ms("prepare")) - finalize_age_ms = float(self.peek_queue_age_ms("finalize")) - decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) - decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) - policy_allowed = bool(policy_snapshot.get("allowed", True)) - if ( - worker_decode_control_enabled - and worker_decode_has_work - and policy_allowed - and decode_budget_remaining > 0 - and (worker_running_requests > 0 or worker_pending_jobs > 0) - ): - return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state - if ( - worker_decode_control_enabled - and worker_pending_jobs > 0 - and policy_allowed - and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) - ): - return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state - if ( - decode_waiting > 0 - and policy_allowed - and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) - ): - return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state - if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): - return "finalize", "finalize_aging", policy_snapshot, worker_state - if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state - if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: - return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state - if decode_waiting > 0 and policy_allowed: - return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state - if finalize_waiting > 0: - return "finalize", "finalize_fallback", policy_snapshot, worker_state - if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state - return "idle", "no_pending_work", policy_snapshot, worker_state - - -class EngineDecodeRuntimeOwner: - def __init__( - self, - *, - get_decode_runtime_counters: Callable[[], Dict[str, int]], - get_micro_batch_wait_s: Callable[[], float], - ) -> None: - self.get_decode_runtime_counters = get_decode_runtime_counters - self.get_micro_batch_wait_s = get_micro_batch_wait_s - self.condition = threading.Condition() - self.pending_jobs: Deque[SchedulerPendingJob] = deque() - self.active_batch: T2SActiveBatch | None = None - self.state_lock = threading.Lock() - self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) - - @staticmethod - def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - if active_batch is None: - return {} - decode_step_index_max = 0 - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: - decode_step_index_max = int(active_batch.step_indices.max().item()) - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": int(decode_step_index_max), - } - - def snapshot_pending_queue_state(self) -> Dict[str, Any]: - with self.condition: - return { - "pending_jobs": int(len(self.pending_jobs)), - "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], - } - - def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: - with self.condition: - self.pending_jobs.append(job) - self.condition.notify_all() - self.refresh_state("engine_decode_pending_enqueue") - - def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): - return [] - pending_jobs = list(self.pending_jobs) - self.pending_jobs.clear() - self.refresh_state("engine_decode_pending_dequeue") - return pending_jobs - - def pending_age_ms(self) -> float: - with self.condition: - if not self.pending_jobs: - return 0.0 - enqueue_time = float(self.pending_jobs[0].enqueue_time) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def has_pending_jobs(self) -> bool: - with self.condition: - return bool(self.pending_jobs) - - def get_active_batch(self) -> T2SActiveBatch | None: - return self.active_batch - - def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: - self.active_batch = active_batch - - def active_batch_summary(self) -> Dict[str, Any]: - return self.summarize_active_batch(self.active_batch) - - def refresh_state(self, last_event: str) -> None: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) - self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] - self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) - self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) - self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) - self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) - self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) - self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) - self.state.last_event = str(last_event) - self.state.updated_at = float(time.perf_counter()) - - def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - pending_state = self.snapshot_pending_queue_state() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(snapshot.get("active_request_count", 0)) - self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] - self.state.prefill_done = bool(snapshot.get("prefill_done", False)) - self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) - self.state.total_cycles = int(snapshot.get("total_cycles", 0)) - self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) - self.state.step_cycles = int(snapshot.get("step_cycles", 0)) - self.state.has_work = bool( - pending_state.get("pending_jobs", 0) - or snapshot.get("active_request_count", 0) - or snapshot.get("has_work", False) - ) - self.state.last_event = str(snapshot.get("last_event", "unknown")) - self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) - - def snapshot_state(self) -> Dict[str, Any]: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - return { - "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), - "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), - "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), - "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), - "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), - "decode_step_index_max": int( - active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) - ), - "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), - "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), - "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), - "has_work": bool( - pending_state.get("pending_jobs", 0) - or active_batch_summary.get("request_count", self.state.active_request_count) - or self.state.has_work - ), - "last_event": str(self.state.last_event), - "updated_at": float(self.state.updated_at), - } - -@dataclass -class SchedulerFinalizeTask: - request_id: str - item: T2SFinishedItem - enqueued_time: float - - -@dataclass -class EngineDispatchTask: - request_id: str - state: T2SRequestState - speed_factor: float - sample_steps: int - media_type: str - prepare_wall_ms: float - prepare_profile_total_ms: float - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - timeout_sec: float | None - enqueue_time: float - worker_job: SchedulerPendingJob | None = None - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - engine_policy_snapshot: Dict[str, Any] | None = None - error: str | None = None - - -@dataclass -class EngineGpuPrepareTask: - request_id: str - cpu_stage: PreparedCpuStage - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - enqueue_time: float - queue_wait_ms: float = 0.0 - error: str | None = None - - -@dataclass -class EngineFinalizeQueueState: - waiting_count: int - waiting_request_ids: List[str] - peak_waiting: int - total_submitted: int - total_completed: int - - -@dataclass -class EngineArbiterState: - total_ticks: int = 0 - total_idle_ticks: int = 0 - total_prepare_dispatches: int = 0 - total_decode_dispatches: int = 0 - total_decode_runtime_ticks: int = 0 - total_finalize_dispatches: int = 0 - decode_budget_remaining: int = 0 - last_stage: str = "idle" - last_reason: str = "init" - last_observed_at: float = 0.0 - last_policy_allowed: bool = True - - -@dataclass -class EngineDecodeRuntimeState: - pending_jobs: int = 0 - pending_request_ids: List[str] = field(default_factory=list) - active_request_count: int = 0 - active_request_ids: List[str] = field(default_factory=list) - prefill_done: bool = False - decode_step_index_max: int = 0 - total_cycles: int = 0 - prefill_cycles: int = 0 - step_cycles: int = 0 - has_work: bool = False - last_event: str = "init" - updated_at: float = 0.0 - - -@dataclass -class RuntimeStateCallbacks: - update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None - complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None - fail: Callable[[str, str], None] | None = None - decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None - - -class WorkerPrepareExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - ) -> None: - self.coordinator = PrepareCoordinator(tts) - self.on_state_change = on_state_change - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def snapshot(self) -> Dict[str, int]: - return dict(self.coordinator.snapshot()) - - def get_max_inflight(self) -> int: - return int(self.coordinator.snapshot().get("max_inflight", 0)) - - def is_idle(self) -> bool: - return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - results = await asyncio.gather( - *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] - ) - return [state for state, _, _ in results] - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - try: - return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) - finally: - self._notify_state_change() - - -class WorkerFinalizeExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ) -> None: - self.tts = tts - self.on_state_change = on_state_change - self.external_submit = external_submit - self.condition = threading.Condition() - self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() - self.pending_peak = 0 - self.inflight = 0 - self.inflight_peak = 0 - self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def get_worker_count(self) -> int: - return int(self.worker_count) - - def get_batch_policy(self) -> Dict[str, Any]: - return { - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_s": float(self.batch_wait_s), - } - - def get_pending_count(self) -> int: - with self.condition: - return int(len(self.pending_tasks)) - - def snapshot(self) -> Dict[str, Any]: - with self.condition: - return { - "finalize_pending": int(len(self.pending_tasks)), - "finalize_pending_peak": int(self.pending_peak), - "finalize_inflight": int(self.inflight), - "finalize_inflight_peak": int(self.inflight_peak), - "finalize_workers": int(self.worker_count), - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), - } - - def is_idle(self) -> bool: - with self.condition: - return self.inflight <= 0 and not self.pending_tasks - - def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - if self.external_submit is not None: - self.external_submit(tasks) - self._notify_state_change() - return - with self.condition: - for task in tasks: - self.pending_tasks.append(task) - self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) - self.condition.notify_all() - self._notify_state_change() - - def begin_execution(self, task_count: int) -> None: - if task_count <= 0: - return - with self.condition: - self.inflight += int(task_count) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self.condition.notify_all() - self._notify_state_change() - - def end_execution(self, task_count: int) -> None: - with self.condition: - self.inflight = max(0, self.inflight - int(task_count)) - self.condition.notify_all() - self._notify_state_change() - - def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: - with self.condition: - while not self.pending_tasks: - self.condition.wait() - selected_tasks = [self.pending_tasks.popleft()] - if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - batch_deadline = time.perf_counter() + self.batch_wait_s - while len(selected_tasks) < self.batch_max_items: - if not self.pending_tasks: - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - continue - first_task = selected_tasks[0] - matched_index = None - for index, task in enumerate(self.pending_tasks): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is not None: - selected_tasks.append(self.pending_tasks[matched_index]) - del self.pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - audio_fragment = self.tts.synthesize_audio_request_local( - semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), - phones=job.state.phones.detach().clone().unsqueeze(0), - prompt_semantic=job.state.prompt_semantic.detach().clone(), - prompt_phones=job.state.prompt_phones.detach().clone(), - refer_spec=( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ), - raw_audio=job.state.raw_audio.detach().clone(), - raw_sr=int(job.state.raw_sr), - speed=float(job.speed_factor), - sample_steps=int(job.sample_steps), - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - return self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - - def _synthesize_finished_audio_batch( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> List[tuple[int, np.ndarray]]: - semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] - phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] - refer_specs = [] - speeds = [] - sample_steps_list = [] - for job, _ in jobs_and_items: - refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - ) - speeds.append(float(job.speed_factor)) - sample_steps_list.append(int(job.sample_steps)) - audio_fragments = self.tts.synthesize_audio_requests_local_batched( - semantic_tokens_list=semantic_tokens_list, - phones_list=phones_list, - refer_specs=refer_specs, - speeds=speeds, - sample_steps_list=sample_steps_list, - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - results: List[tuple[int, np.ndarray]] = [] - for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): - results.append( - self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - ) - return results - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - if not jobs_and_items: - return 0.0, [] - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - return float(synth_ms), batch_results - - -class WorkerCompletionBridge: - def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def notify_done_future(self, job: SchedulerPendingJob) -> None: - if job.done_loop is None or job.done_future is None: - return - try: - job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) - except RuntimeError: - pass - - def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.complete is None: - return - self.runtime_callbacks.complete(request_id, extra) - - def runtime_fail(self, request_id: str | None, error: str) -> None: - if request_id is None or self.runtime_callbacks.fail is None: - return - self.runtime_callbacks.fail(request_id, error) - - @staticmethod - def build_completed_job_result( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - finished_at: float | None = None, - ) -> Dict[str, Any]: - finished_at = float(time.perf_counter() if finished_at is None else finished_at) - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "decode_admission_wait_ms": float(job.admission_wait_ms), - "engine_policy_wait_ms": float(job.engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.result = result - return result - - @staticmethod - def build_runtime_complete_payload( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - ) -> Dict[str, Any]: - return { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "sample_rate": int(sample_rate), - "worker_profile": dict(job.result or {}), - } - - def complete_job( - self, - job: SchedulerPendingJob, - *, - runtime_request_id: str | None, - runtime_extra: Optional[Dict[str, Any]] = None, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_complete(runtime_request_id, runtime_extra) - - def fail_job( - self, - job: SchedulerPendingJob, - *, - error: str, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.error = str(error) - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_fail(job.engine_request_id, str(error)) - - def complete_finalize_task( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - job: SchedulerPendingJob, - item: T2SFinishedItem, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - runtime_extra: Optional[Dict[str, Any]] = None - with condition: - if job_registry.get(item.request_id) is not job: - return - self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) - self.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=runtime_extra, - on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), - notify_waiters=condition.notify_all, - ) - - def fail_jobs( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - request_ids: List[str], - error: str, - ) -> None: - if not request_ids: - return - with condition: - for request_id in request_ids: - job = job_registry.get(request_id) - if job is None: - continue - self.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), - ) - condition.notify_all() - - -class WorkerDecodeExecutor: - def __init__(self, tts: TTS, max_steps: int) -> None: - self.tts = tts - self.max_steps = int(max_steps) - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def execute_prefill_merge( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if not pending_jobs: - return { - "executed": False, - "active_batch": active_batch, - "pending_jobs": [], - "prefill_elapsed_s": 0.0, - "merge_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - admitted_finished: List[T2SFinishedItem] = [] - prefill_elapsed_s = 0.0 - merge_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - self._sync_device() - prefill_start = time.perf_counter() - mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - prefill_elapsed_s = time.perf_counter() - prefill_start - if add_prefill_time is not None: - add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(admitted_finished) - merge_start = time.perf_counter() - active_batch = merge_active_batches( - self.tts.t2s_model.model, - active_batch, - admitted_active_batch, - ) - merge_elapsed_s = time.perf_counter() - merge_start - if add_merge_time is not None: - add_merge_time( - [] if active_batch is None else list(active_batch.request_ids), - merge_elapsed_s, - ) - except Exception as exc: - error = str(exc) - error_request_ids = [job.request_id for job in pending_jobs] - if finalize_error is not None: - finalize_error(error_request_ids, error) - return { - "executed": True, - "active_batch": active_batch, - "pending_jobs": list(pending_jobs), - "prefill_elapsed_s": float(prefill_elapsed_s), - "merge_elapsed_s": float(merge_elapsed_s), - "finished_items": list(admitted_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_step( - self, - *, - active_batch: Optional[T2SActiveBatch], - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if active_batch is None: - return { - "executed": False, - "active_batch": None, - "request_ids": [], - "decode_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - active_request_ids: List[str] = [] - step_finished: List[T2SFinishedItem] = [] - decode_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - decode_elapsed_s = time.perf_counter() - decode_start - if add_decode_time is not None: - add_decode_time(active_request_ids, decode_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(step_finished) - except Exception as exc: - error = str(exc) - error_request_ids = list(active_request_ids) - if finalize_error is not None: - finalize_error(error_request_ids, error) - active_batch = None - return { - "executed": True, - "active_batch": active_batch, - "request_ids": active_request_ids, - "decode_elapsed_s": float(decode_elapsed_s), - "finished_items": list(step_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_cycle( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - result = { - "executed": False, - "prefill_merge_executed": False, - "decode_step_executed": False, - "active_batch": active_batch, - "prefill_phase": {}, - "decode_phase": {}, - } - prefill_phase = self.execute_prefill_merge( - pending_jobs=list(pending_jobs), - active_batch=result["active_batch"], - mark_prefill_started=mark_prefill_started, - add_prefill_time=add_prefill_time, - add_merge_time=add_merge_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - prefill_executed = bool(prefill_phase.get("executed", False)) - result["prefill_phase"] = prefill_phase - result["active_batch"] = prefill_phase.get("active_batch") - if prefill_executed: - result["executed"] = True - result["prefill_merge_executed"] = True - decode_phase = self.execute_decode_step( - active_batch=result["active_batch"], - add_decode_time=add_decode_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - decode_executed = bool(decode_phase.get("executed", False)) - result["decode_phase"] = decode_phase - result["active_batch"] = decode_phase.get("active_batch") - if decode_executed: - result["executed"] = True - result["decode_step_executed"] = True - return result - - -class WorkerDecodeLegacyShell: - def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: - self.condition = condition - self.micro_batch_wait_s = float(micro_batch_wait_s) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: - if active_batch is None: - return None - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": ( - int(active_batch.step_indices.max().item()) - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 - else 0 - ), - } - - def current_backlog_locked(self) -> int: - running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) - return int(len(self.pending_jobs) + running_requests) - - def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: - self.pending_jobs.append(job) - - def snapshot_locked(self) -> Dict[str, Any]: - active_batch_summary = self._summarize_active_batch(self.active_batch) - executor_local_pending_jobs = int(len(self.pending_jobs)) - executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) - executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) - return { - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "executor_local_active_batch": active_batch_summary, - } - - def is_idle_locked(self) -> bool: - return self.active_batch is None and not self.pending_jobs - - def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs and self.active_batch is None: - self.condition.wait(timeout=self.micro_batch_wait_s) - elif wait_for_batch and self.pending_jobs: - self.condition.wait(timeout=self.micro_batch_wait_s) - if not self.pending_jobs: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def has_decode_runtime_work(self) -> bool: - with self.condition: - return bool(self.pending_jobs or self.active_batch is not None) - - def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: - active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) - decode_step_index_max = 0 - prefill_done = False - if self.active_batch is not None: - prefill_done = bool(self.active_batch.prefill_done) - if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: - decode_step_index_max = int(self.active_batch.step_indices.max().item()) - return { - "pending_jobs": int(len(self.pending_jobs)), - "active_request_count": int(len(active_request_ids)), - "active_request_ids": active_request_ids[:32], - "prefill_done": bool(prefill_done), - "decode_step_index_max": int(decode_step_index_max), - "total_cycles": int(total_cycles), - "prefill_cycles": int(prefill_cycles), - "step_cycles": int(step_cycles), - "has_work": bool(self.pending_jobs or self.active_batch is not None), - "last_event": str(last_event), - "updated_at": float(time.perf_counter()), - } - - def run_prefill_merge_once_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_prefill_merge(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_step_once_nonblocking( - self, - *, - external_active_batch: Optional[T2SActiveBatch], - execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - active_batch = self.active_batch if external_active_batch is None else external_active_batch - result = execute_decode_step(active_batch) - if external_active_batch is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_cycle_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - on_cycle_executed: Callable[[Dict[str, Any]], None] | None, - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_decode_cycle(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - if result.get("executed") and on_cycle_executed is not None: - on_cycle_executed(result) - return result - - def run_loop( - self, - *, - run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], - ) -> None: - while True: - executed = run_decode_cycle_nonblocking() - if executed.get("executed"): - continue - wait_for_batch = self.active_batch is None - pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) - if pending_jobs: - with self.condition: - self.pending_jobs = pending_jobs + self.pending_jobs - self.condition.notify_all() - continue - time.sleep(self.micro_batch_wait_s) - - -class WorkerDecodeRuntimeTracker: - def __init__( - self, - runtime_callbacks: RuntimeStateCallbacks | None = None, - ) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.total_cycles = 0 - self.prefill_cycles = 0 - self.step_cycles = 0 - - def get_counters(self) -> Dict[str, int]: - return { - "total_cycles": int(self.total_cycles), - "prefill_cycles": int(self.prefill_cycles), - "step_cycles": int(self.step_cycles), - } - - def record_cycle(self, result: Dict[str, Any]) -> None: - if not bool(result.get("executed")): - return - self.total_cycles += 1 - if bool(result.get("prefill_merge_executed")): - self.prefill_cycles += 1 - if bool(result.get("decode_step_executed")): - self.step_cycles += 1 - - def build_runtime_summary_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> Dict[str, Any]: - return legacy_shell.build_runtime_summary_locked( - total_cycles=int(self.total_cycles), - prefill_cycles=int(self.prefill_cycles), - step_cycles=int(self.step_cycles), - last_event=str(last_event), - ) - - def notify_runtime_update_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> None: - if self.runtime_callbacks.decode_runtime_update is None: - return - snapshot = self.build_runtime_summary_locked( - legacy_shell=legacy_shell, - last_event=last_event, - ) - self.runtime_callbacks.decode_runtime_update(snapshot) - - -class UnifiedSchedulerWorker: - def __init__( - self, - tts: TTS, - max_steps: int = 1500, - micro_batch_wait_ms: int = 5, - runtime_callbacks: RuntimeStateCallbacks | None = None, - external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ): - self.tts = tts - self.max_steps = int(max_steps) - self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.condition = threading.Condition() - self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) - self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) - self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) - self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) - self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) - self.finalize_executor = WorkerFinalizeExecutor( - tts, - on_state_change=self._notify_worker_state_change, - external_submit=external_finalize_submit, - ) - self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) - self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) - self.engine_decode_control_enabled = ( - str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} - ) - self.job_registry = SchedulerJobRegistry(self.condition) - self.worker_thread: threading.Thread | None = None - if not self.engine_decode_control_enabled: - self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) - self.worker_thread.start() - self.finalize_threads = [] - if external_finalize_submit is None: - self.finalize_threads = [ - threading.Thread( - target=self._run_finalize_loop, - name=f"unified-t2s-finalize-{worker_index}", - daemon=True, - ) - for worker_index in range(self.finalize_executor.get_worker_count()) - ] - for finalize_thread in self.finalize_threads: - finalize_thread.start() - - def _notify_worker_state_change(self) -> None: - with self.condition: - self.condition.notify_all() - - def _current_decode_backlog_locked(self) -> int: - return self.decode_legacy_shell.current_backlog_locked() - - def get_micro_batch_wait_s(self) -> float: - return float(self.micro_batch_wait_s) - - def is_engine_decode_control_enabled(self) -> bool: - return bool(self.engine_decode_control_enabled) - - def get_prepare_max_inflight(self) -> int: - return int(self.prepare_executor.get_max_inflight()) - - def get_capacity_limits(self) -> Dict[str, int]: - return { - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - } - - def get_finalize_batch_policy(self) -> Dict[str, Any]: - return dict(self.finalize_executor.get_batch_policy()) - - def get_decode_runtime_counters(self) -> Dict[str, int]: - with self.condition: - return self.decode_runtime_tracker.get_counters() - - def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: - decode_backlog = self._current_decode_backlog_locked() - finalize_pending = int(self.finalize_executor.get_pending_count()) - prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) - blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max - blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max - return ( - not blocked_decode and not blocked_finalize, - { - "decode_backlog": decode_backlog, - "finalize_pending": finalize_pending, - "prepare_inflight": prepare_inflight, - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - }, - ) - - def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - last_snapshot: Dict[str, int] = {} - while True: - with self.condition: - allowed, snapshot = self._can_accept_submit_locked() - last_snapshot = snapshot - if allowed: - return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot - if deadline is not None and time.perf_counter() >= deadline: - raise TimeoutError( - "scheduler submit admission timeout " - f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" - ) - self.condition.wait(timeout=self.micro_batch_wait_s) - - def _admission_snapshot_locked(self) -> Dict[str, int]: - _, snapshot = self._can_accept_submit_locked() - return snapshot - - async def submit_async( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - return await asyncio.to_thread( - self.submit, - state, - speed_factor, - sample_steps, - media_type, - prepare_wall_ms, - prepare_profile_total_ms, - done_loop, - done_future, - engine_request_id, - timeout_sec, - skip_capacity_wait, - admission_wait_ms_override, - admission_snapshot_override, - engine_policy_wait_ms, - engine_dispatch_wait_ms, - enqueue_pending, - ) - - def snapshot(self) -> dict: - with self.condition: - prepare_state = self.prepare_executor.snapshot() - finalize_state = self.finalize_executor.snapshot() - shell_state = self.decode_legacy_shell.snapshot_locked() - decode_runtime_counters = self.decode_runtime_tracker.get_counters() - engine_owned_decode_state = bool(self.engine_decode_control_enabled) - active_batch_summary = shell_state.get("executor_local_active_batch") - executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) - executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) - executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) - return { - "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, - "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, - "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), - "legacy_state_owner_mode": not engine_owned_decode_state, - "decode_state_owner": "engine" if engine_owned_decode_state else "worker", - "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), - "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), - "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), - "prepare_inflight": prepare_state["inflight"], - "prepare_peak_inflight": prepare_state["peak_inflight"], - "prepare_max_inflight": prepare_state.get("max_inflight", 0), - "prepare_state": dict(prepare_state), - **finalize_state, - "decode_backlog_max": self.decode_backlog_max, - "finalize_pending_max": self.finalize_pending_max, - "active_batch": {} if engine_owned_decode_state else active_batch_summary, - "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, - "total_submitted": self.job_registry.submitted_count(), - "total_finished": self.job_registry.finished_count(), - "drained": self.is_drained(), - } - - def is_drained(self) -> bool: - with self.condition: - return ( - self.decode_legacy_shell.is_idle_locked() - and self.job_registry.is_empty() - and self.prepare_executor.is_idle() - and self.finalize_executor.is_idle() - ) - - def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: - deadline = time.perf_counter() + max(0.0, timeout_sec) - while time.perf_counter() < deadline: - if self.is_drained(): - return True - time.sleep(poll_interval_sec) - return self.is_drained() - - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - if skip_capacity_wait: - with self.condition: - admission_snapshot = ( - dict(admission_snapshot_override) - if admission_snapshot_override is not None - else dict(self._admission_snapshot_locked()) - ) - admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) - else: - admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - admission_wait_ms=float(admission_wait_ms), - engine_policy_wait_ms=float(engine_policy_wait_ms), - engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - engine_request_id=engine_request_id or state.request_id, - ) - with self.condition: - self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) - if enqueue_pending: - self.decode_legacy_shell.enqueue_pending_job_locked(job) - self.condition.notify_all() - if enqueue_pending: - self._notify_decode_runtime_state("submit") - self._runtime_update( - job.engine_request_id, - EngineStatus.QUEUED, - { - "scheduler_request_id": job.request_id, - "decode_admission_wait_ms": float(admission_wait_ms), - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), - "admission_snapshot": dict(admission_snapshot), - }, - ) - return job - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return await self.prepare_executor.prepare_states_batch_async(specs) - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) - - def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in pending_jobs: - job.first_schedule_time = float(started_at) - self._runtime_update( - job.engine_request_id, - EngineStatus.GPU_PREPARING, - {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, - ) - - def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.prefill_ms += delta_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - activate_request_ids: List[str] = [] - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.finalize_wait_ms += float(delta_ms) - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks: List[SchedulerFinalizeTask] = [] - with self.condition: - for item in items: - job = self.job_registry.get(item.request_id) - if job is not None: - self._runtime_update( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - }, - ) - tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) - self.finalize_executor.enqueue_tasks(tasks) - - def begin_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.begin_execution(task_count) - - def end_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.end_execution(task_count) - - def record_external_job_done(self, request_id: str) -> None: - with self.condition: - self.job_registry.mark_finished_and_remove(request_id) - self.condition.notify_all() - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) - - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - self.completion_bridge.complete_finalize_task( - condition=self.condition, - job_registry=self.job_registry, - job=job, - item=item, - sample_rate=sample_rate, - audio_data=audio_data, - ) - - def _finalize_error(self, request_ids: List[str], error: str) -> None: - self.completion_bridge.fail_jobs( - condition=self.condition, - job_registry=self.job_registry, - request_ids=request_ids, - error=error, - ) - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def _notify_done_future(self, job: SchedulerPendingJob) -> None: - self.completion_bridge.notify_done_future(job) - - def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.update is None: - return - self.runtime_callbacks.update(request_id, status, extra) - - def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - self.completion_bridge.runtime_complete(request_id, extra) - - def _runtime_fail(self, request_id: str | None, error: str) -> None: - self.completion_bridge.runtime_fail(request_id, error) - - def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: - return self.decode_runtime_tracker.build_runtime_summary_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _notify_decode_runtime_state(self, last_event: str) -> None: - with self.condition: - self.decode_runtime_tracker.notify_runtime_update_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: - with self.condition: - self.decode_runtime_tracker.record_cycle(result) - - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) - - def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) - - def has_decode_runtime_work(self) -> bool: - return self.decode_legacy_shell.has_decode_runtime_work() - - def execute_prefill_merge( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_prefill_merge( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_step( - self, - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_decode_step( - active_batch=active_batch, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_cycle( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_executor.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - self._record_decode_runtime_cycle(result) - return result - - def run_prefill_merge_once_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("prefill_merge") - return result - - def run_decode_step_once_nonblocking( - self, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_step_once_nonblocking( - external_active_batch=external_active_batch, - execute_decode_step=lambda batch_state: self.execute_decode_step( - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("decode_step") - return result - - def run_decode_cycle_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_cycle_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - on_cycle_executed=None, - ) - if result.get("executed") and emit_runtime_state: - self._notify_decode_runtime_state("decode_cycle") - return result - - def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - with self.condition: - for task in tasks: - job = self.job_registry.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: - self._runtime_update( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_registry.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self.finalize_executor.end_execution(len(tasks)) - - def _run_finalize_loop(self) -> None: - while True: - tasks = self.finalize_executor.take_task_batch_blocking() - self.execute_finalize_tasks(tasks) - - def _run_loop(self) -> None: - self.decode_legacy_shell.run_loop( - run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() - ) - - -def set_scheduler_seed(seed: int): - if seed in ["", None]: - return - seed = int(seed) - if seed < 0: - return - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - def handle_pack_ogg(): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except (RuntimeError, ValueError): - handle_pack_ogg() - return io_buffer - - -def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer.write(data.tobytes()) - return io_buffer - - -def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer = BytesIO() - sf.write(io_buffer, data, rate, format="wav") - return io_buffer - - -def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - "s16le", - "-ar", - str(rate), - "-ac", - "1", - "-i", - "pipe:0", - "-c:a", - "aac", - "-b:a", - "192k", - "-vn", - "-f", - "adts", - "pipe:1", - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - io_buffer.write(out) - return io_buffer - - -def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): - if media_type == "ogg": - io_buffer = pack_ogg(io_buffer, data, rate) - elif media_type == "aac": - io_buffer = pack_aac(io_buffer, data, rate) - elif media_type == "wav": - io_buffer = pack_wav(io_buffer, data, rate) - else: - io_buffer = pack_raw(io_buffer, data, rate) - io_buffer.seek(0) - return io_buffer - - -def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): - wav_buf = BytesIO() - with wave.open(wav_buf, "wb") as vfout: - vfout.setnchannels(channels) - vfout.setsampwidth(sample_width) - vfout.setframerate(sample_rate) - vfout.writeframes(frame_input) - wav_buf.seek(0) - return wav_buf.read() - - -class UnifiedTTSEngine: +class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): @staticmethod def _env_flag(name: str, default: bool) -> bool: value = os.environ.get(name) @@ -2783,2048 +42,4 @@ class UnifiedTTSEngine: self.tts = tts self.cut_method_names = set(cut_method_names) self.control_callbacks = control_callbacks or RuntimeControlCallbacks() - self.reference_registry = ReferenceRegistry() - self.model_registry = ModelRegistry( - t2s_weights_path=str(self.tts.configs.t2s_weights_path), - vits_weights_path=str(self.tts.configs.vits_weights_path), - ) - self.request_registry = EngineRequestRegistry( - recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) - ) - self.engine_job_registry = SchedulerJobRegistry(threading.Lock()) - self.scheduler_worker = UnifiedSchedulerWorker( - tts, - max_steps=max_steps, - micro_batch_wait_ms=micro_batch_wait_ms, - runtime_callbacks=RuntimeStateCallbacks( - update=self._update_request_state, - complete=self._complete_request_state, - fail=self._fail_request_state, - decode_runtime_update=self._update_engine_decode_runtime_state, - ), - external_finalize_submit=self._enqueue_worker_finished_for_finalize, - ) - self.direct_tts_lock = threading.RLock() - self.management_lock = threading.RLock() - worker_capacity_limits = self.scheduler_worker.get_capacity_limits() - prepare_max_inflight = int(self.scheduler_worker.get_prepare_max_inflight()) - self.engine_policy_config = EnginePolicyConfig( - enabled=self._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), - poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), - decode_backlog_soft_max=max( - 0, - self._env_int( - "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", - int(worker_capacity_limits["decode_backlog_max"]), - ), - ), - finalize_pending_soft_max=max( - 0, - self._env_int( - "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", - int(worker_capacity_limits["finalize_pending_max"]), - ), - ), - prepare_inflight_soft_max=max( - 0, - self._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), - ), - active_decode_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), - ready_for_prefill_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), - active_request_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), - ) - self.engine_arbiter_config = EngineArbiterConfig( - poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), - decode_burst=max(1, self._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), - prepare_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), - finalize_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), - ) - self.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( - get_decode_runtime_counters=self.scheduler_worker.get_decode_runtime_counters, - get_micro_batch_wait_s=self.scheduler_worker.get_micro_batch_wait_s, - ) - self.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") - self.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") - self.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") - self.engine_dispatch_last_snapshot: Dict[str, Any] = {} - self.engine_policy_arbiter = EnginePolicyArbiterController( - policy_config=self.engine_policy_config, - arbiter_config=self.engine_arbiter_config, - snapshot_request_registry=self._snapshot_request_registry, - get_worker_state=self.get_scheduler_state, - snapshot_prepare_state=self._snapshot_engine_prepare_state, - snapshot_finalize_state=self._snapshot_engine_finalize_state, - snapshot_dispatch_state=self._snapshot_engine_dispatch_state, - snapshot_decode_runtime_state=self._snapshot_engine_decode_runtime_state, - snapshot_job_registry=self._snapshot_engine_job_registry, - peek_queue_age_ms=self._peek_queue_age_ms, - merge_request_state_profile=self._merge_request_state_profile, - ) - self.engine_arbiter_thread = threading.Thread( - target=self._run_engine_arbiter_loop, - name="unified-engine-arbiter", - daemon=True, - ) - self.engine_arbiter_thread.start() - - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.request_registry.register( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state( - self, - request_id: str, - status: str, - extra: Optional[Dict[str, Any]] = None, - ) -> None: - self.request_registry.update(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.merge_profile(request_id, extra) - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.engine_dispatch_queue_owner.snapshot( - max_request_ids=16, - extra={"last_policy_snapshot": dict(self.engine_dispatch_last_snapshot or {})}, - ) - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.engine_job_registry.register(job, keep_job=True) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.get(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.pop(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.engine_job_registry.snapshot(max_request_ids=32) - - def _is_engine_drained(self) -> bool: - prepare_empty = self.engine_prepare_queue_owner.is_drained() - dispatch_empty = self.engine_dispatch_queue_owner.is_drained() - finalize_empty = self.engine_finalize_queue_owner.is_drained() - decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() - job_empty = self.engine_job_registry.is_empty() - worker_state = self.scheduler_worker.snapshot() - return bool( - prepare_empty - and dispatch_empty - and finalize_empty - and decode_pending_empty - and job_empty - and self.engine_decode_runtime_owner.get_active_batch() is None - and int(worker_state.get("prepare_inflight", 0)) <= 0 - and int(worker_state.get("finalize_inflight", 0)) <= 0 - and int(worker_state.get("finalize_pending", 0)) <= 0 - ) - - def _record_engine_job_done(self, request_id: str) -> None: - self.engine_job_registry.mark_finished_and_remove(request_id) - self.scheduler_worker.record_external_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - completion_bridge = self.scheduler_worker.completion_bridge - completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - completion_bridge.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), - on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), - ) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - if not request_ids: - return - completion_bridge = self.scheduler_worker.completion_bridge - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - completion_bridge.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), - ) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for job in jobs: - job.prefill_ms += delta_ms - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - activate_request_ids: List[str] = [] - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] - self._enqueue_worker_finished_for_finalize(tasks) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.engine_decode_runtime_owner.refresh_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - if self.scheduler_worker.is_engine_decode_control_enabled(): - return - self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - return self.engine_policy_arbiter.snapshot_state() - - def _notify_engine_arbiter(self) -> None: - self.engine_policy_arbiter.notify() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.engine_decode_runtime_owner.enqueue_pending_job(job) - self._notify_engine_arbiter() - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - if queue_name == "prepare": - return self.engine_prepare_queue_owner.peek_oldest_age_ms("enqueue_time") - elif queue_name == "finalize": - return self.engine_finalize_queue_owner.peek_oldest_age_ms("enqueued_time") - elif queue_name == "decode_runtime_pending": - return self.engine_decode_runtime_owner.pending_age_ms() - else: - return self.engine_dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") - - def _engine_has_pending_work(self) -> bool: - if self.scheduler_worker.is_engine_decode_control_enabled(): - if self.engine_decode_runtime_owner.has_pending_jobs(): - return True - if self.scheduler_worker.is_engine_decode_control_enabled() and self._snapshot_engine_decode_runtime_state().get("active_request_count", 0) > 0: - return True - if self.engine_prepare_queue_owner.has_items(): - return True - if self.engine_finalize_queue_owner.has_items(): - return True - return self.engine_dispatch_queue_owner.has_items() - - @staticmethod - def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: - if future.done(): - return - future.set_exception(error) - - def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - @staticmethod - def _resolve_prepare_future( - future: asyncio.Future, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if future.done(): - return - future.set_result(payload) - - def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - def _notify_prepare_result( - self, - task: EngineGpuPrepareTask, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) - except RuntimeError: - pass - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - if engine_request_id not in [None, ""]: - self._update_request_state( - str(engine_request_id), - EngineStatus.GPU_PREPARING, - { - "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), - "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), - "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), - "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), - }, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - task = EngineGpuPrepareTask( - request_id=spec.request_id, - cpu_stage=cpu_stage, - done_loop=loop, - done_future=done_future, - engine_request_id=engine_request_id or spec.request_id, - enqueue_time=time.perf_counter(), - ) - self.engine_prepare_queue_owner.enqueue(task) - self._notify_engine_arbiter() - state, prepare_exec_started_at, prepare_exec_finished_at = await done_future - return state, prepare_exec_started_at, prepare_exec_finished_at - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - for task in tasks: - job = self._get_engine_job(task.request_id) - if job is not None: - self._update_request_state( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": task.item.finish_reason, - "semantic_len": int(task.item.semantic_tokens.shape[0]), - "finish_idx": int(task.item.finish_idx), - }, - ) - self.engine_finalize_queue_owner.enqueue_many(tasks) - self._notify_engine_arbiter() - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - finalize_policy = self.scheduler_worker.get_finalize_batch_policy() - return self.engine_finalize_queue_owner.take_finalize_batch( - finalize_mode=str(finalize_policy.get("finalize_mode", "async")), - batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), - batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), - use_vocoder=bool(self.tts.configs.use_vocoder), - ) - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - task = EngineDispatchTask( - request_id=state.request_id, - state=state, - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id or state.request_id, - timeout_sec=timeout_sec, - enqueue_time=time.perf_counter(), - ) - self.engine_dispatch_queue_owner.enqueue(task) - self._notify_engine_arbiter() - self._merge_request_state_profile( - task.engine_request_id or task.request_id, - { - "engine_dispatch_queue_depth_on_enqueue": int(self._snapshot_engine_dispatch_state()["waiting_count"]), - }, - ) - return task - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() - self.engine_dispatch_last_snapshot = dict(policy_snapshot) - return stage, reason, policy_snapshot, worker_state - - def _run_engine_prepare_once(self) -> bool: - task = self.engine_prepare_queue_owner.pop_left() - if task is None: - return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) - if task.engine_request_id not in [None, ""]: - self._merge_request_state_profile( - str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, - ) - self.engine_prepare_queue_owner.mark_completed(1) - self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self._fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True - - def _run_engine_finalize_once(self) -> bool: - tasks = self._take_engine_finalize_batch_nonblocking() - if not tasks: - return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: - job = self._get_engine_job(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return False - now = time.perf_counter() - for task in tasks: - job = self._get_engine_job(task.request_id) - if job is not None: - job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) - for job, item in jobs_and_items: - self._update_request_state( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) - for job, _ in jobs_and_items: - job.synth_ms += float(synth_ms) - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._fail_engine_jobs([task.request_id for task in tasks], str(exc)) - finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.engine_finalize_queue_owner.mark_completed(len(tasks), notify=True) - return True - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - if not bool(policy_snapshot.get("allowed", True)): - return False - dispatch_task = self.engine_dispatch_queue_owner.pop_left() - if dispatch_task is None: - return False - dispatched_at = time.perf_counter() - dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) - dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_policy_snapshot = dict(policy_snapshot) - try: - worker_job = self.scheduler_worker.submit( - state=dispatch_task.state, - speed_factor=dispatch_task.speed_factor, - sample_steps=dispatch_task.sample_steps, - media_type=dispatch_task.media_type, - prepare_wall_ms=dispatch_task.prepare_wall_ms, - prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, - done_loop=dispatch_task.done_loop, - done_future=dispatch_task.done_future, - engine_request_id=dispatch_task.engine_request_id, - timeout_sec=dispatch_task.timeout_sec, - skip_capacity_wait=True, - admission_wait_ms_override=0.0, - admission_snapshot_override=dict(worker_state), - engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, - engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, - enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), - ) - dispatch_task.worker_job = worker_job - self._register_engine_job(worker_job) - if self.scheduler_worker.is_engine_decode_control_enabled(): - self._enqueue_engine_decode_pending_job(worker_job) - self.engine_dispatch_queue_owner.mark_completed(1) - return True - except Exception as exc: - dispatch_task.error = str(exc) - self._fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) - self._notify_dispatch_error(dispatch_task, exc) - return True - - def _run_engine_decode_runtime_once(self) -> bool: - if not self.scheduler_worker.is_engine_decode_control_enabled(): - return False - runtime_state = self._snapshot_engine_decode_runtime_state() - pending_jobs = self._take_engine_decode_pending_jobs_nonblocking( - wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 - ) - result = self.scheduler_worker.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=self.engine_decode_runtime_owner.get_active_batch(), - external_bookkeeping=True, - ) - prefill_phase = dict(result.get("prefill_phase") or {}) - if prefill_phase.get("error"): - self._fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) - else: - prefill_jobs = list(prefill_phase.get("pending_jobs") or []) - self._add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) - self._add_engine_merge_time( - [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), - float(prefill_phase.get("merge_elapsed_s", 0.0)), - ) - self._enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) - decode_phase = dict(result.get("decode_phase") or {}) - if decode_phase.get("error"): - self._fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) - else: - self._add_engine_decode_time( - list(decode_phase.get("request_ids") or []), - float(decode_phase.get("decode_elapsed_s", 0.0)), - ) - self._enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) - self.engine_decode_runtime_owner.set_active_batch(result.get("active_batch")) - if result.get("executed", False): - self._refresh_engine_decode_runtime_state("engine_decode_cycle") - return bool(result.get("executed", False)) - - def _run_engine_arbiter_loop(self) -> None: - while True: - if not self._engine_has_pending_work(): - self._mark_arbiter_tick(stage="idle", reason="no_pending_work", policy_allowed=True) - self.engine_policy_arbiter.wait() - continue - stage, reason, policy_snapshot, worker_state = self._select_engine_stage() - policy_allowed = bool(policy_snapshot.get("allowed", True)) - executed = False - if stage == "prepare": - executed = self._run_engine_prepare_once() - elif stage == "finalize": - executed = self._run_engine_finalize_once() - elif stage == "decode_dispatch": - executed = self._run_engine_dispatch_once(policy_snapshot, worker_state) - elif stage == "decode_runtime": - executed = self._run_engine_decode_runtime_once() - if not executed: - self._mark_arbiter_tick(stage="idle", reason=f"{stage}_not_ready", policy_allowed=policy_allowed) - self.engine_policy_arbiter.wait() - continue - self._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.complete(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.request_registry.fail(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.request_registry.snapshot() - - @staticmethod - def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: - if component is None or not hasattr(component, "snapshot"): - return None - try: - return dict(component.snapshot()) - except Exception: - return None - - def _build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) - - def _build_engine_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) - - async def _wait_for_engine_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - return await self.engine_policy_arbiter.wait_for_policy_admission( - request_id=request_id, - timeout_sec=timeout_sec, - ) - - def _build_stage_summary( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self._build_stage_counters(request_registry, worker_state) - bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) - ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) - text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) - - return { - **counters, - "engine_drained": bool(self._is_engine_drained()), - "admission_config": { - "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), - "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), - }, - "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), - "engine_arbiter_state": self._snapshot_engine_arbiter_state(), - "engine_decode_runtime_state": self._snapshot_engine_decode_runtime_state(), - "engine_job_registry": self._snapshot_engine_job_registry(), - "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), - "engine_prepare_state": self._snapshot_engine_prepare_state(), - "engine_finalize_state": self._snapshot_engine_finalize_state(), - "engine_dispatcher_state": self._snapshot_engine_dispatch_state(), - "active_batch": dict(worker_state.get("active_batch") or {}), - "prepare_state": dict(worker_state.get("prepare_state") or {}), - "bert_batch_worker_state": bert_worker_state, - "ref_semantic_worker_state": ref_semantic_worker_state, - "text_preprocessor_state": text_preprocessor_state, - } - - def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - return self.request_registry.collect_summaries(request_ids) - - def _has_active_request(self, request_id: str) -> bool: - return self.request_registry.has_active(request_id) - - @staticmethod - def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: - text = payload.get("text") - prompt_text = payload.get("prompt_text") - return { - "text_len": 0 if text is None else len(str(text)), - "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), - "text_lang": payload.get("text_lang"), - "prompt_lang": payload.get("prompt_lang"), - "ref_audio_path": payload.get("ref_audio_path"), - } - - @staticmethod - def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: - total = 0.0 - for item in items: - value = item.get(key, 0.0) - if isinstance(value, (int, float)): - total += float(value) - return total - - def _build_direct_segment_trace( - self, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - for index, segment_text in enumerate(segment_texts): - prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} - worker_item = worker_profiles[index] if index < len(worker_profiles) else {} - prepare_profile = dict(prepare_item.get("prepare_profile", {})) - results.append( - { - "segment_index": index, - "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), - "text_len": len(str(segment_text)), - "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), - "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), - "prepare_engine_gpu_queue_wait_ms": float( - dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) - ), - "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), - "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), - "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), - "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), - "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), - "merge_ms": float(worker_item.get("merge_ms", 0.0)), - "decode_ms": float(worker_item.get("decode_ms", 0.0)), - "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), - "synth_ms": float(worker_item.get("synth_ms", 0.0)), - "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), - "decode_steps": int(worker_item.get("decode_steps", 0)), - "semantic_len": int(worker_item.get("semantic_len", 0)), - "finish_reason": worker_item.get("finish_reason"), - "norm_text": prepare_profile.get("norm_text"), - } - ) - return results - - def _build_direct_scheduler_profile( - self, - *, - backend: str, - request_start: float, - response_ready_at: float, - audio_bytes: int, - sample_rate: int, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - pack_ms: float, - response_overhead_ms: float, - ) -> Dict[str, Any]: - segment_trace = self._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) - prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") - prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") - engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") - engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") - decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") - queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") - prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") - merge_ms = self._sum_profile_field(worker_profiles, "merge_ms") - decode_ms = self._sum_profile_field(worker_profiles, "decode_ms") - finalize_wait_ms = self._sum_profile_field(worker_profiles, "finalize_wait_ms") - synth_ms = self._sum_profile_field(worker_profiles, "synth_ms") - worker_total_ms = self._sum_profile_field(worker_profiles, "worker_total_ms") - decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) - semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) - request_other_ms = max( - 0.0, - request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, - ) - return { - "backend": backend, - "backend_mode": backend, - "segment_count": len(segment_texts), - "sample_rate": int(sample_rate), - "audio_bytes": int(audio_bytes), - "request_total_ms": request_total_ms, - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "engine_policy_wait_ms": engine_policy_wait_ms, - "engine_dispatch_wait_ms": engine_dispatch_wait_ms, - "decode_admission_wait_ms": decode_admission_wait_ms, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": prefill_ms, - "merge_ms": merge_ms, - "decode_ms": decode_ms, - "finalize_wait_ms": finalize_wait_ms, - "synth_ms": synth_ms, - "pack_ms": pack_ms, - "response_overhead_ms": response_overhead_ms, - "worker_total_ms": worker_total_ms, - "request_other_ms": request_other_ms, - "decode_steps": decode_steps, - "semantic_len": semantic_len, - "prepare_segments": list(prepare_profiles), - "worker_segments": list(worker_profiles), - "segment_trace": segment_trace, - "prepare_aggregate": self._aggregate_numeric_dicts(prepare_profile_dicts), - } - - def _build_legacy_direct_profile( - self, - *, - backend: str, - fallback_reason: str | None, - request_start: float, - finished_at: float, - sample_rate: int | None = None, - audio_bytes: int = 0, - pack_ms: float = 0.0, - chunk_count: int = 0, - stream_total_bytes: int = 0, - first_chunk_ms: float | None = None, - ) -> Dict[str, Any]: - request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) - legacy_infer_ms = max(0.0, request_total_ms - pack_ms) - return { - "backend": backend, - "backend_mode": backend, - "fallback_reason": fallback_reason, - "request_total_ms": request_total_ms, - "prepare_ms": 0.0, - "queue_wait_ms": 0.0, - "prefill_ms": 0.0, - "merge_ms": 0.0, - "decode_ms": 0.0, - "finalize_wait_ms": 0.0, - "synth_ms": 0.0, - "pack_ms": pack_ms, - "worker_total_ms": legacy_infer_ms, - "request_other_ms": 0.0, - "legacy_infer_ms": legacy_infer_ms, - "sample_rate": int(sample_rate) if sample_rate is not None else None, - "audio_bytes": int(audio_bytes), - "chunk_count": int(chunk_count), - "stream_total_bytes": int(stream_total_bytes), - "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), - } - - def _build_scheduler_submit_profile( - self, - *, - backend: str, - request_start: float, - response_ready_at: float, - audio_bytes: int, - sample_rate: int, - prepare_spec_build_ms: float, - prepare_wall_ms: float, - prepare_executor_queue_ms: float, - prepare_executor_run_ms: float, - prepare_profile_total_ms: float, - prepare_profile_wall_ms: float, - prepare_other_ms: float, - engine_policy_wait_ms: float, - api_after_prepare_ms: float, - api_wait_result_ms: float, - pack_ms: float, - response_overhead_ms: float, - worker_profile: Dict[str, Any], - ) -> Dict[str, Any]: - worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - request_other_ms = max( - 0.0, - request_total_ms - - prepare_wall_ms - - engine_policy_wait_ms - - api_after_prepare_ms - - worker_total_ms - - api_wait_result_ms - - pack_ms, - ) - result = { - "backend": backend, - "backend_mode": backend, - "audio_bytes": int(audio_bytes), - "sample_rate": int(sample_rate), - "prepare_spec_build_ms": prepare_spec_build_ms, - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_executor_queue_ms": prepare_executor_queue_ms, - "prepare_executor_run_ms": prepare_executor_run_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile_wall_ms": prepare_profile_wall_ms, - "prepare_other_ms": prepare_other_ms, - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "api_after_prepare_ms": api_after_prepare_ms, - "api_wait_result_ms": api_wait_result_ms, - "pack_ms": pack_ms, - "response_overhead_ms": response_overhead_ms, - "request_total_ms": request_total_ms, - "request_other_ms": request_other_ms, - } - result.update({key: value for key, value in worker_profile.items()}) - return result - - @staticmethod - def _format_ms_header(value: Any) -> str: - return f"{float(value):.3f}" - - def _build_scheduler_submit_headers( - self, - *, - request_id: str, - media_type: str, - sample_rate: int, - profile: Dict[str, Any], - ) -> Dict[str, str]: - prepare_profile = dict(profile.get("prepare_profile", {})) - headers = { - "X-Request-Id": request_id, - "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), - "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), - "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), - "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), - "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), - "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), - "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), - "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), - "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), - "X-Prepare-Executor-Queue-Ms": self._format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), - "X-Prepare-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), - "X-Prepare-Executor-Run-Ms": self._format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), - "X-Prepare-Profile-Total-Ms": self._format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), - "X-Prepare-Profile-Wall-Ms": self._format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), - "X-Prepare-Other-Ms": self._format_ms_header(profile.get("prepare_other_ms", 0.0)), - "X-Api-After-Prepare-Ms": self._format_ms_header(profile.get("api_after_prepare_ms", 0.0)), - "X-Prefill-Ms": self._format_ms_header(profile.get("prefill_ms", 0.0)), - "X-Merge-Ms": self._format_ms_header(profile.get("merge_ms", 0.0)), - "X-Decode-Ms": self._format_ms_header(profile.get("decode_ms", 0.0)), - "X-Finalize-Wait-Ms": self._format_ms_header(profile.get("finalize_wait_ms", 0.0)), - "X-Synth-Ms": self._format_ms_header(profile.get("synth_ms", 0.0)), - "X-Worker-Residual-Ms": self._format_ms_header(profile.get("worker_residual_ms", 0.0)), - "X-Worker-Other-Ms": self._format_ms_header(profile.get("worker_other_ms", 0.0)), - "X-Pack-Ms": self._format_ms_header(profile.get("pack_ms", 0.0)), - "X-Worker-Total-Ms": self._format_ms_header(profile.get("worker_total_ms", 0.0)), - "X-Api-Wait-Result-Ms": self._format_ms_header(profile.get("api_wait_result_ms", 0.0)), - "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), - "X-Sample-Rate": str(int(sample_rate)), - "X-Response-Overhead-Ms": self._format_ms_header(profile.get("response_overhead_ms", 0.0)), - "X-Request-Other-Ms": self._format_ms_header(profile.get("request_other_ms", 0.0)), - "X-Request-Total-Ms": self._format_ms_header(profile.get("request_total_ms", 0.0)), - } - headers.update( - { - "X-Prepare-Prompt-Text-Ms": self._format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), - "X-Prepare-Target-Text-Ms": self._format_ms_header(prepare_profile.get("text_features_ms", 0.0)), - "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), - "X-Prepare-Target-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), - "X-Prepare-Prompt-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), - "X-Prepare-Target-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), - "X-Prepare-Prompt-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), - "X-Prepare-Target-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), - "X-Prepare-Prompt-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), - "X-Prepare-Target-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), - "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Prompt-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), - "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), - "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), - "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), - "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), - "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), - "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), - "X-Prepare-Prompt-Semantic-CPU-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), - "X-Prepare-Ref-Spec-Ms": self._format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), - "X-Prepare-Ref-Spec-Wait-Ms": self._format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), - "X-Prepare-Ref-Bundle-Ms": self._format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), - "X-Prepare-Tensorize-Ms": self._format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), - "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), - "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), - } - ) - return headers - - def _build_scheduler_debug_request_profile( - self, - *, - state: T2SRequestState, - item: T2SFinishedItem, - batch_request_count: int, - prepare_batch_wall_ms: float, - decode_batch_wall_ms: float, - batch_request_total_ms: float, - ) -> Dict[str, Any]: - prepare_profile = dict(state.prepare_profile) - prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) - return { - "backend": "scheduler_debug", - "backend_mode": "scheduler_debug", - "batch_request_count": int(batch_request_count), - "batch_prepare_wall_ms": float(prepare_batch_wall_ms), - "batch_decode_wall_ms": float(decode_batch_wall_ms), - "batch_request_total_ms": float(batch_request_total_ms), - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), - "prepare_profile": prepare_profile, - "decode_steps": int(item.finish_idx), - "finish_idx": int(item.finish_idx), - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_reason": item.finish_reason, - "norm_text": state.norm_text, - "norm_prompt_text": state.norm_prompt_text, - } - - @staticmethod - def _build_scheduler_debug_batch_profile( - *, - request_count: int, - max_steps: int, - prepare_batch_wall_ms: float, - decode_batch_wall_ms: float, - request_total_ms: float, - finished_items: Sequence[T2SFinishedItem], - ) -> Dict[str, Any]: - finish_reason_counts: Dict[str, int] = {} - total_semantic_len = 0 - for item in finished_items: - finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 - total_semantic_len += int(item.semantic_tokens.shape[0]) - return { - "request_count": int(request_count), - "max_steps": int(max_steps), - "prepare_batch_wall_ms": float(prepare_batch_wall_ms), - "decode_batch_wall_ms": float(decode_batch_wall_ms), - "request_total_ms": float(request_total_ms), - "total_semantic_len": int(total_semantic_len), - "finish_reason_counts": finish_reason_counts, - } - - def _normalize_lang(self, value: str | None) -> str | None: - if value in [None, ""]: - return value - return str(value).lower() - - @staticmethod - def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: - totals: Dict[str, float] = {} - for item in items: - for key, value in item.items(): - if isinstance(value, (int, float)): - totals[key] = totals.get(key, 0.0) + float(value) - return totals - - def _apply_default_reference(self, req: dict) -> dict: - normalized = dict(req) - default_ref = self.reference_registry.get_default() - if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: - normalized["ref_audio_path"] = default_ref.ref_audio_path - if "text_lang" in normalized: - normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) - if "prompt_lang" in normalized: - normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) - return normalized - - def check_params(self, req: dict) -> Optional[str]: - text = req.get("text", "") - text_lang = req.get("text_lang", "") - ref_audio_path = req.get("ref_audio_path", "") - media_type = req.get("media_type", "wav") - prompt_lang = req.get("prompt_lang", "") - text_split_method = req.get("text_split_method", "cut5") - - if ref_audio_path in [None, ""]: - return "ref_audio_path is required" - if text in [None, ""]: - return "text is required" - if text_lang in [None, ""]: - return "text_lang is required" - if text_lang.lower() not in self.tts.configs.languages: - return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" - if prompt_lang in [None, ""]: - return "prompt_lang is required" - if prompt_lang.lower() not in self.tts.configs.languages: - return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" - if media_type not in ["wav", "raw", "ogg", "aac"]: - return f"media_type: {media_type} is not supported" - if text_split_method not in self.cut_method_names: - return f"text_split_method:{text_split_method} is not supported" - return None - - @staticmethod - def _base_request_defaults() -> Dict[str, Any]: - return { - "request_id": None, - "text": None, - "text_lang": None, - "ref_audio_path": None, - "aux_ref_audio_paths": None, - "prompt_text": "", - "prompt_lang": None, - "top_k": 15, - "top_p": 1.0, - "temperature": 1.0, - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": 1.0, - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": "wav", - "streaming_mode": False, - "return_fragment": False, - "fixed_length_chunk": False, - "response_streaming": False, - "parallel_infer": False, - "repetition_penalty": 1.35, - "sample_steps": 32, - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - "early_stop_num": -1, - "ready_step": 0, - "timeout_sec": None, - } - - def _normalize_engine_request( - self, - payload: dict | NormalizedEngineRequest, - *, - request_id: str | None = None, - normalize_streaming: bool = False, - error_prefix: str = "request 参数非法: ", - ) -> NormalizedEngineRequest: - if isinstance(payload, NormalizedEngineRequest): - normalized_payload = payload.to_payload() - else: - normalized_payload = self._base_request_defaults() - normalized_payload.update(dict(payload)) - if request_id not in [None, ""]: - normalized_payload["request_id"] = str(request_id) - elif normalized_payload.get("request_id") in [None, ""]: - raise ValueError("request_id is required after normalization") - normalized_payload = self._apply_default_reference(normalized_payload) - if normalize_streaming: - normalized_payload = self._normalize_streaming_mode(normalized_payload) - error = self.check_params(normalized_payload) - if error is not None: - raise ValueError(f"{error_prefix}{error}") - timeout_sec = normalized_payload.get("timeout_sec") - if timeout_sec in [None, ""]: - parsed_timeout = None - else: - parsed_timeout = float(timeout_sec) - aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") - if aux_ref_audio_paths in [None, "", []]: - normalized_aux_ref_audio_paths = None - else: - normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths] - return NormalizedEngineRequest( - request_id=str(normalized_payload["request_id"]), - text=str(normalized_payload["text"]), - text_lang=str(normalized_payload["text_lang"]), - ref_audio_path=str(normalized_payload["ref_audio_path"]), - prompt_lang=str(normalized_payload["prompt_lang"]), - prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), - aux_ref_audio_paths=normalized_aux_ref_audio_paths, - top_k=int(normalized_payload["top_k"]), - top_p=float(normalized_payload["top_p"]), - temperature=float(normalized_payload["temperature"]), - repetition_penalty=float(normalized_payload["repetition_penalty"]), - early_stop_num=int(normalized_payload.get("early_stop_num", -1)), - ready_step=int(normalized_payload.get("ready_step", 0)), - text_split_method=str(normalized_payload["text_split_method"]), - batch_size=int(normalized_payload["batch_size"]), - batch_threshold=float(normalized_payload["batch_threshold"]), - split_bucket=bool(normalized_payload["split_bucket"]), - speed_factor=float(normalized_payload["speed_factor"]), - fragment_interval=float(normalized_payload["fragment_interval"]), - seed=int(normalized_payload["seed"]), - media_type=str(normalized_payload["media_type"]), - streaming_mode=normalized_payload["streaming_mode"], - return_fragment=bool(normalized_payload.get("return_fragment", False)), - fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), - response_streaming=bool(normalized_payload.get("response_streaming", False)), - parallel_infer=bool(normalized_payload["parallel_infer"]), - sample_steps=int(normalized_payload["sample_steps"]), - super_sampling=bool(normalized_payload["super_sampling"]), - overlap_length=int(normalized_payload["overlap_length"]), - min_chunk_length=int(normalized_payload["min_chunk_length"]), - timeout_sec=parsed_timeout, - ) - - @staticmethod - def _normalize_streaming_mode(req: dict) -> dict: - normalized = dict(req) - streaming_mode = normalized.get("streaming_mode", False) - return_fragment = normalized.get("return_fragment", False) - if streaming_mode is False: - normalized["streaming_mode"] = False - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 0: - normalized["streaming_mode"] = False - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 1 or streaming_mode is True: - normalized["streaming_mode"] = False - normalized["return_fragment"] = True - normalized["fixed_length_chunk"] = False - elif streaming_mode == 2: - normalized["streaming_mode"] = True - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 3: - normalized["streaming_mode"] = True - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = True - else: - raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") - normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) - return normalized - - @staticmethod - def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: - return aux_ref_audio_paths not in [None, [], ()] - - def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - if normalized.response_streaming: - if normalized.return_fragment or normalized.fixed_length_chunk: - return "legacy_direct_fragment", "fragment_streaming_mode" - return "legacy_direct_streaming", "streaming_mode" - if self._is_aux_ref_enabled(normalized.aux_ref_audio_paths): - return "legacy_direct_aux_ref", "aux_ref_audio_paths" - if normalized.super_sampling: - return "legacy_direct_super_sampling", "super_sampling" - if normalized.prompt_text in [None, ""]: - return "legacy_direct_missing_prompt", "missing_prompt_text" - return "scheduler_v1_direct", None - - def _iter_legacy_direct_tts_bytes( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> Generator[bytes, None, None]: - payload = normalized.to_payload() - media_type = normalized.media_type - request_id = normalized.request_id - request_start = time.perf_counter() - chunk_count = 0 - stream_total_bytes = 0 - first_chunk_ms: float | None = None - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - try: - with self.direct_tts_lock: - tts_generator = self.tts.run(payload) - first_chunk = True - current_media_type = media_type - for sr, chunk in tts_generator: - if first_chunk: - first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) - self._update_request_state( - request_id, - EngineStatus.STREAMING, - { - "backend": backend, - "backend_mode": backend, - "fallback_reason": fallback_reason, - "sample_rate": int(sr), - }, - ) - if first_chunk and media_type == "wav": - header = wave_header_chunk(sample_rate=sr) - chunk_count += 1 - stream_total_bytes += len(header) - yield header - current_media_type = "raw" - first_chunk = False - elif first_chunk: - first_chunk = False - packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() - chunk_count += 1 - stream_total_bytes += len(packed_chunk) - yield packed_chunk - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - self._complete_request_state( - request_id, - dict( - self._build_legacy_direct_profile( - backend=backend, - fallback_reason=fallback_reason, - request_start=request_start, - finished_at=time.perf_counter(), - audio_bytes=stream_total_bytes, - chunk_count=chunk_count, - stream_total_bytes=stream_total_bytes, - first_chunk_ms=first_chunk_ms, - ), - streaming_completed=True, - ), - ) - - def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: - if isinstance(req, NormalizedEngineRequest): - normalized = req - else: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - ) - backend, _ = self._select_direct_backend(normalized) - return backend == "scheduler_v1_direct" - - def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: - payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized - return self.tts.text_preprocessor.pre_seg_text( - str(payload["text"]), - str(payload["text_lang"]), - str(payload.get("text_split_method", "cut5")), - ) - - def _build_segment_request( - self, - normalized: NormalizedEngineRequest, - *, - request_id: str, - text: str, - ) -> NormalizedEngineRequest: - payload = normalized.to_payload() - payload["request_id"] = request_id - payload["text"] = text - payload["streaming_mode"] = False - payload["return_fragment"] = False - payload["fixed_length_chunk"] = False - payload["response_streaming"] = False - return self._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") - - async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: - request_start = time.perf_counter() - request_id = normalized.request_id - media_type = normalized.media_type - segment_texts = self._segment_direct_text(normalized) - if not segment_texts: - raise ValueError("text preprocessing returned no valid segments") - self._update_request_state( - request_id, - EngineStatus.CPU_PREPARING, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, - ) - segment_specs: List[SchedulerRequestSpec] = [] - for segment_index, segment_text in enumerate(segment_texts): - segment_request = self._build_segment_request( - normalized, - request_id=f"{request_id}_seg_{segment_index:03d}", - text=segment_text, - ) - segment_specs.append(self.build_scheduler_submit_spec(segment_request)) - - prepared_items = await asyncio.gather( - *[ - self._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=time.perf_counter(), - engine_request_id=None, - ) - for spec in segment_specs - ] - ) - prepare_profiles: List[Dict[str, Any]] = [] - loop = asyncio.get_running_loop() - done_futures: List[asyncio.Future] = [] - self._update_request_state( - request_id, - EngineStatus.READY_FOR_PREFILL, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, - ) - for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profiles.append( - { - "request_id": spec.request_id, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": dict(state.prepare_profile), - } - ) - done_future = loop.create_future() - done_futures.append(done_future) - await self._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=None, - timeout_sec=normalized.timeout_sec, - ) - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, - ) - timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) - jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) - for profile_item, job in zip(prepare_profiles, jobs): - profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) - profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) - self._merge_request_state_profile( - request_id, - { - "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), - "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), - "prepare_aggregate": self._aggregate_numeric_dicts( - [item["prepare_profile"] for item in prepare_profiles] - ), - }, - ) - - sample_rate: int | None = None - audio_parts: List[np.ndarray] = [] - worker_profiles: List[Dict[str, Any]] = [] - fragment_interval = float(normalized.fragment_interval) - silence_chunk: Optional[np.ndarray] = None - for job in jobs: - if job.error is not None: - raise RuntimeError(job.error) - if job.audio_data is None or job.sample_rate is None or job.result is None: - raise RuntimeError(f"{job.request_id} finished without audio result") - if sample_rate is None: - sample_rate = int(job.sample_rate) - silence_samples = int(fragment_interval * float(sample_rate)) - if silence_samples > 0: - silence_chunk = np.zeros(silence_samples, dtype=np.int16) - elif int(job.sample_rate) != sample_rate: - raise RuntimeError("segment sample rate mismatch") - audio_parts.append(job.audio_data) - if silence_chunk is not None: - audio_parts.append(silence_chunk.copy()) - worker_profiles.append(dict(job.result)) - if sample_rate is None or not audio_parts: - raise RuntimeError("direct scheduler backend produced no audio") - self._update_request_state( - request_id, - EngineStatus.FINALIZING, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, - ) - merged_audio = np.concatenate(audio_parts, axis=0) - pack_start = time.perf_counter() - audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() - pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) - direct_profile = self._build_direct_scheduler_profile( - backend="scheduler_v1_direct", - request_start=request_start, - response_ready_at=time.perf_counter(), - audio_bytes=len(audio_bytes), - sample_rate=int(sample_rate), - segment_texts=segment_texts, - prepare_profiles=prepare_profiles, - worker_profiles=worker_profiles, - pack_ms=pack_ms, - response_overhead_ms=0.0, - ) - self._complete_request_state( - request_id, - dict(direct_profile, streaming_completed=False), - ) - return DirectTTSExecution( - media_type=media_type, - streaming=False, - audio_bytes=audio_bytes, - request_id=request_id, - ) - - def _run_legacy_direct_tts_blocking( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - normalized_payload = normalized.to_payload() - request_id = normalized.request_id - media_type = normalized.media_type - request_start = time.perf_counter() - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - with self.direct_tts_lock: - tts_generator = self.tts.run(normalized_payload) - try: - sr, audio_data = next(tts_generator) - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - self._update_request_state( - request_id, - EngineStatus.FINALIZING, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - pack_start = time.perf_counter() - packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) - self._complete_request_state( - request_id, - dict( - self._build_legacy_direct_profile( - backend=backend, - fallback_reason=fallback_reason, - request_start=request_start, - finished_at=time.perf_counter(), - sample_rate=int(sr), - audio_bytes=len(packed_audio), - pack_ms=pack_ms, - ), - streaming_completed=False, - ), - ) - return DirectTTSExecution( - media_type=media_type, - streaming=False, - audio_bytes=packed_audio, - request_id=request_id, - ) - - async def _run_direct_tts_via_legacy_backend( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - if normalized.response_streaming: - return DirectTTSExecution( - media_type=normalized.media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ), - request_id=normalized.request_id, - ) - return await asyncio.to_thread( - self._run_legacy_direct_tts_blocking, - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - error_prefix="", - ) - request_id = normalized.request_id - media_type = normalized.media_type - backend, fallback_reason = self._select_direct_backend(normalized) - self._register_request_state( - request_id=request_id, - api_mode="tts", - backend=backend, - media_type=media_type, - response_streaming=bool(normalized.response_streaming), - deadline_ts=( - time.perf_counter() + float(normalized.timeout_sec) - if normalized.timeout_sec is not None - else None - ), - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state( - request_id, - EngineStatus.VALIDATED, - { - "request_source": "direct_tts", - "selected_backend": backend, - "fallback_reason": fallback_reason, - }, - ) - if backend == "scheduler_v1_direct": - try: - return await self._run_direct_tts_via_scheduler(normalized) - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - return await self._run_direct_tts_via_legacy_backend( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - def run_direct_tts(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - error_prefix="", - ) - request_id = normalized.request_id - media_type = normalized.media_type - backend, fallback_reason = self._select_direct_backend(normalized) - if not self._has_active_request(request_id): - self._register_request_state( - request_id=request_id, - api_mode="tts", - backend=backend, - media_type=media_type, - response_streaming=bool(normalized.response_streaming), - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state( - request_id, - EngineStatus.VALIDATED, - { - "request_source": "direct_tts", - "selected_backend": backend, - "fallback_reason": fallback_reason, - }, - ) - if backend != "scheduler_v1_direct": - if normalized.response_streaming: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ), - request_id=request_id, - ) - return self._run_legacy_direct_tts_blocking( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - normalized_payload = normalized.to_payload() - if normalized.response_streaming: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend="legacy_direct_sync_compat", - fallback_reason="sync_direct_compat", - ), - request_id=request_id, - ) - return self._run_legacy_direct_tts_blocking( - normalized, - backend="legacy_direct_sync_compat", - fallback_reason="sync_direct_compat", - ) - - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: - specs: List[SchedulerRequestSpec] = [] - for index, payload in enumerate(request_items): - normalized = self._normalize_engine_request( - payload, - request_id=str(payload.get("request_id") or f"req_{index:03d}"), - error_prefix=f"request[{index}] 参数非法: ", - ) - specs.append(normalized.to_scheduler_spec()) - return specs - - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: - normalized = self._normalize_engine_request( - payload, - request_id=( - payload.request_id - if isinstance(payload, NormalizedEngineRequest) - else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") - ), - ) - return normalized.to_scheduler_spec() - - @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return [ - { - "request_id": state.request_id, - "ready_step": int(state.ready_step), - "ref_audio_path": str(state.ref_audio_path), - "prompt_semantic_len": int(state.prompt_semantic.shape[0]), - "all_phone_len": int(state.all_phones.shape[0]), - "bert_len": int(state.all_bert_features.shape[-1]), - "norm_text": state.norm_text, - } - for state in states - ] - - @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return [ - { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - } - for item in items - ] - - async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: - request_start = time.perf_counter() - set_scheduler_seed(seed) - specs = self.build_scheduler_request_specs(request_items) - request_ids = [spec.request_id for spec in specs] - for spec in specs: - self._register_request_state( - request_id=spec.request_id, - api_mode="scheduler_debug", - backend="scheduler_debug", - media_type="wav", - response_streaming=False, - meta={ - "text_len": len(spec.text), - "prompt_text_len": len(spec.prompt_text), - "text_lang": spec.text_lang, - "prompt_lang": spec.prompt_lang, - "ref_audio_path": str(spec.ref_audio_path), - "ready_step": int(spec.ready_step), - }, - ) - self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) - self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) - prepare_started_at = time.perf_counter() - try: - states = await self.scheduler_worker.prepare_states_batch_async(specs) - except Exception as exc: - for request_id in request_ids: - self._fail_request_state(request_id, str(exc)) - raise - prepare_finished_at = time.perf_counter() - prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) - for state in states: - self._update_request_state( - state.request_id, - EngineStatus.ACTIVE_DECODE, - { - "prepare_profile": dict(state.prepare_profile), - "norm_text": state.norm_text, - "norm_prompt_text": state.norm_prompt_text, - }, - ) - decode_started_at = time.perf_counter() - try: - finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) - except Exception as exc: - for request_id in request_ids: - self._fail_request_state(request_id, str(exc)) - raise - decode_finished_at = time.perf_counter() - decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) - request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) - finished_map = {item.request_id: item for item in finished} - request_profiles: List[Dict[str, Any]] = [] - for state in states: - item = finished_map.get(state.request_id) - if item is None: - self._fail_request_state(state.request_id, "scheduler_debug finished without result") - continue - request_profile = self._build_scheduler_debug_request_profile( - state=state, - item=item, - batch_request_count=len(states), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - batch_request_total_ms=request_total_ms, - ) - request_profiles.append( - { - "request_id": state.request_id, - "profile": dict(request_profile), - } - ) - self._complete_request_state( - state.request_id, - dict(request_profile), - ) - return SchedulerDebugExecution( - payload={ - "message": "success", - "request_count": len(states), - "max_steps": int(max_steps), - "batch_profile": self._build_scheduler_debug_batch_profile( - request_count=len(states), - max_steps=int(max_steps), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - request_total_ms=request_total_ms, - finished_items=finished, - ), - "requests": self.summarize_scheduler_states(states), - "finished": self.summarize_scheduler_finished(finished), - "request_profiles": request_profiles, - "request_traces": self._collect_request_summaries(request_ids), - } - ) - - async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: - request_start = time.perf_counter() - prepare_start = request_start - normalized = self._normalize_engine_request( - payload, - request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), - ) - spec = self.build_scheduler_submit_spec(normalized) - deadline_ts = None - timeout_sec = normalized.timeout_sec - if timeout_sec is not None: - try: - deadline_ts = request_start + float(timeout_sec) - except Exception: - deadline_ts = None - self._register_request_state( - request_id=spec.request_id, - api_mode="scheduler_submit", - backend="scheduler_v1", - media_type=normalized.media_type, - response_streaming=False, - deadline_ts=deadline_ts, - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) - spec_ready_at = time.perf_counter() - prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) - self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=spec_ready_at, - engine_request_id=spec.request_id, - ) - except Exception as exc: - self._fail_request_state(spec.request_id, str(exc)) - raise - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) - prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) - prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile = dict(state.prepare_profile) - prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) - self._update_request_state( - spec.request_id, - EngineStatus.READY_FOR_PREFILL, - { - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": prepare_profile, - }, - ) - api_after_prepare_start = time.perf_counter() - loop = asyncio.get_running_loop() - done_future = loop.create_future() - await self._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=normalized.media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=spec.request_id, - timeout_sec=normalized.timeout_sec, - ) - api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) - try: - job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) - except Exception as exc: - self._fail_request_state(spec.request_id, str(exc)) - raise - wait_return_at = time.perf_counter() - if job.error is not None: - raise RuntimeError(job.error) - if job.audio_data is None or job.sample_rate is None or job.result is None: - self._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") - raise RuntimeError(f"{job.request_id} finished without audio result") - pack_start = time.perf_counter() - audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_end = time.perf_counter() - pack_ms = (pack_end - pack_start) * 1000.0 - api_wait_result_ms = 0.0 - if job.result_ready_time is not None: - api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) - response_ready_at = time.perf_counter() - response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) - submit_profile = self._build_scheduler_submit_profile( - backend="scheduler_v1", - request_start=request_start, - response_ready_at=response_ready_at, - audio_bytes=len(audio_data), - sample_rate=int(job.sample_rate), - prepare_spec_build_ms=prepare_spec_build_ms, - prepare_wall_ms=prepare_wall_ms, - prepare_executor_queue_ms=prepare_executor_queue_ms, - prepare_executor_run_ms=prepare_executor_run_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - prepare_profile_wall_ms=prepare_profile_wall_ms, - prepare_other_ms=prepare_other_ms, - engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), - api_after_prepare_ms=api_after_prepare_ms, - api_wait_result_ms=api_wait_result_ms, - pack_ms=pack_ms, - response_overhead_ms=response_overhead_ms, - worker_profile=dict(job.result or {}), - ) - headers = self._build_scheduler_submit_headers( - request_id=job.request_id, - media_type=job.media_type, - sample_rate=int(job.sample_rate), - profile=submit_profile, - ) - self._merge_request_state_profile( - spec.request_id, - dict(submit_profile, response_headers_emitted=True), - ) - return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) - - def get_scheduler_state(self) -> dict: - return self.scheduler_worker.snapshot() - - def get_runtime_state(self) -> dict: - model_state = self.model_registry.snapshot() - default_ref = self.reference_registry.get_default() - scheduler_state = self.get_scheduler_state() - request_registry = self._snapshot_request_registry() - engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) - engine_arbiter_state = self._snapshot_engine_arbiter_state() - engine_decode_runtime_state = self._snapshot_engine_decode_runtime_state() - engine_job_registry = self._snapshot_engine_job_registry() - engine_prepare_state = self._snapshot_engine_prepare_state() - engine_finalize_state = self._snapshot_engine_finalize_state() - engine_dispatcher_state = self._snapshot_engine_dispatch_state() - engine_drained = self._is_engine_drained() - return { - "message": "success", - "default_reference": { - "ref_audio_path": default_ref.ref_audio_path, - "updated_at": default_ref.updated_at, - }, - "model_registry": { - "generation": model_state.generation, - "t2s_generation": model_state.t2s_generation, - "vits_generation": model_state.vits_generation, - "t2s_weights_path": model_state.t2s_weights_path, - "vits_weights_path": model_state.vits_weights_path, - "updated_at": model_state.updated_at, - }, - "worker_state": scheduler_state, - "engine_policy": engine_policy, - "engine_arbiter_state": engine_arbiter_state, - "engine_decode_runtime_state": engine_decode_runtime_state, - "engine_job_registry": engine_job_registry, - "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), - "engine_prepare_state": engine_prepare_state, - "engine_finalize_state": engine_finalize_state, - "engine_dispatcher_state": engine_dispatcher_state, - "engine_drained": bool(engine_drained), - "request_registry": request_registry, - "stage_summary": self._build_stage_summary(request_registry, scheduler_state), - } - - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: - if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): - raise TimeoutError("scheduler worker did not drain before model reload") - - def set_refer_audio(self, refer_audio_path: str | None) -> dict: - if refer_audio_path in [None, ""]: - state = self.reference_registry.clear() - return {"message": "success", "default_ref_audio_path": state.ref_audio_path} - if not os.path.exists(str(refer_audio_path)): - raise FileNotFoundError(f"{refer_audio_path} not exists") - with self.management_lock: - with self.direct_tts_lock: - self.tts.set_ref_audio(str(refer_audio_path)) - state = self.reference_registry.set_default(str(refer_audio_path)) - return {"message": "success", "default_ref_audio_path": state.ref_audio_path} - - def set_gpt_weights(self, weights_path: str) -> dict: - if weights_path in ["", None]: - raise ValueError("gpt weight path is required") - with self.management_lock: - self._wait_for_safe_reload() - with self.direct_tts_lock: - self.tts.init_t2s_weights(weights_path) - self.tts.refresh_runtime_components() - state = self.model_registry.mark_t2s_reload(str(weights_path)) - return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} - - def set_sovits_weights(self, weights_path: str) -> dict: - if weights_path in ["", None]: - raise ValueError("sovits weight path is required") - with self.management_lock: - self._wait_for_safe_reload() - with self.direct_tts_lock: - self.tts.init_vits_weights(weights_path) - self.tts.refresh_runtime_components() - state = self.model_registry.mark_vits_reload(str(weights_path)) - return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} - - def handle_control(self, command: str) -> None: - if command == "restart": - if self.control_callbacks.restart is None: - os.execl(sys.executable, sys.executable, *sys.argv) - self.control_callbacks.restart() - return - if command == "exit": - if self.control_callbacks.exit is None: - os.kill(os.getpid(), signal.SIGTERM) - return - self.control_callbacks.exit() - return - raise ValueError(f"unsupported command: {command}") + EngineCompositionBuilder(self).build(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py new file mode 100644 index 00000000..bcc8bf0d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -0,0 +1,1399 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from io import BytesIO +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous +from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed, wave_header_chunk +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + DirectTTSExecution, + EngineStatus, + NormalizedEngineRequest, + SchedulerDebugExecution, + SchedulerPendingJob, + SchedulerSubmitExecution, +) + + +class EngineApiFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def tts(self): + return self.owner.tts + + @property + def cut_method_names(self): + return self.owner.cut_method_names + + @property + def reference_registry(self): + return self.owner.reference_registry + + @property + def direct_tts_lock(self): + return self.owner.direct_tts_lock + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ): + return self.owner._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.owner._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.owner._merge_request_state_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.owner._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.owner._fail_request_state(request_id, error) + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.owner._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ): + return await self.owner._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.owner.request_registry.collect_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.owner.request_registry.has_active(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + text = payload.get("text") + prompt_text = payload.get("prompt_text") + return { + "text_len": 0 if text is None else len(str(text)), + "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), + "text_lang": payload.get("text_lang"), + "prompt_lang": payload.get("prompt_lang"), + "ref_audio_path": payload.get("ref_audio_path"), + } + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + total = 0.0 + for item in items: + value = item.get(key, 0.0) + if isinstance(value, (int, float)): + total += float(value) + return total + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + results: List[Dict[str, Any]] = [] + for index, segment_text in enumerate(segment_texts): + prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} + worker_item = worker_profiles[index] if index < len(worker_profiles) else {} + prepare_profile = dict(prepare_item.get("prepare_profile", {})) + results.append( + { + "segment_index": index, + "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), + "text_len": len(str(segment_text)), + "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), + "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "prepare_engine_gpu_queue_wait_ms": float( + dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) + ), + "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), + "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), + "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), + "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), + "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), + "merge_ms": float(worker_item.get("merge_ms", 0.0)), + "decode_ms": float(worker_item.get("decode_ms", 0.0)), + "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), + "synth_ms": float(worker_item.get("synth_ms", 0.0)), + "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), + "decode_steps": int(worker_item.get("decode_steps", 0)), + "semantic_len": int(worker_item.get("semantic_len", 0)), + "finish_reason": worker_item.get("finish_reason"), + "norm_text": prepare_profile.get("norm_text"), + } + ) + return results + + def _build_direct_scheduler_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + pack_ms: float, + response_overhead_ms: float, + ) -> Dict[str, Any]: + segment_trace = self._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") + prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") + engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") + decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") + queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") + prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") + merge_ms = self._sum_profile_field(worker_profiles, "merge_ms") + decode_ms = self._sum_profile_field(worker_profiles, "decode_ms") + finalize_wait_ms = self._sum_profile_field(worker_profiles, "finalize_wait_ms") + synth_ms = self._sum_profile_field(worker_profiles, "synth_ms") + worker_total_ms = self._sum_profile_field(worker_profiles, "worker_total_ms") + decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) + semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, + ) + return { + "backend": backend, + "backend_mode": backend, + "segment_count": len(segment_texts), + "sample_rate": int(sample_rate), + "audio_bytes": int(audio_bytes), + "request_total_ms": request_total_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "engine_policy_wait_ms": engine_policy_wait_ms, + "engine_dispatch_wait_ms": engine_dispatch_wait_ms, + "decode_admission_wait_ms": decode_admission_wait_ms, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": prefill_ms, + "merge_ms": merge_ms, + "decode_ms": decode_ms, + "finalize_wait_ms": finalize_wait_ms, + "synth_ms": synth_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "worker_total_ms": worker_total_ms, + "request_other_ms": request_other_ms, + "decode_steps": decode_steps, + "semantic_len": semantic_len, + "prepare_segments": list(prepare_profiles), + "worker_segments": list(worker_profiles), + "segment_trace": segment_trace, + "prepare_aggregate": self._aggregate_numeric_dicts(prepare_profile_dicts), + } + + def _build_legacy_direct_profile( + self, + *, + backend: str, + fallback_reason: str | None, + request_start: float, + finished_at: float, + sample_rate: int | None = None, + audio_bytes: int = 0, + pack_ms: float = 0.0, + chunk_count: int = 0, + stream_total_bytes: int = 0, + first_chunk_ms: float | None = None, + ) -> Dict[str, Any]: + request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) + legacy_infer_ms = max(0.0, request_total_ms - pack_ms) + return { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "request_total_ms": request_total_ms, + "prepare_ms": 0.0, + "queue_wait_ms": 0.0, + "prefill_ms": 0.0, + "merge_ms": 0.0, + "decode_ms": 0.0, + "finalize_wait_ms": 0.0, + "synth_ms": 0.0, + "pack_ms": pack_ms, + "worker_total_ms": legacy_infer_ms, + "request_other_ms": 0.0, + "legacy_infer_ms": legacy_infer_ms, + "sample_rate": int(sample_rate) if sample_rate is not None else None, + "audio_bytes": int(audio_bytes), + "chunk_count": int(chunk_count), + "stream_total_bytes": int(stream_total_bytes), + "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), + } + + def _build_scheduler_submit_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + prepare_spec_build_ms: float, + prepare_wall_ms: float, + prepare_executor_queue_ms: float, + prepare_executor_run_ms: float, + prepare_profile_total_ms: float, + prepare_profile_wall_ms: float, + prepare_other_ms: float, + engine_policy_wait_ms: float, + api_after_prepare_ms: float, + api_wait_result_ms: float, + pack_ms: float, + response_overhead_ms: float, + worker_profile: Dict[str, Any], + ) -> Dict[str, Any]: + worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms + - prepare_wall_ms + - engine_policy_wait_ms + - api_after_prepare_ms + - worker_total_ms + - api_wait_result_ms + - pack_ms, + ) + result = { + "backend": backend, + "backend_mode": backend, + "audio_bytes": int(audio_bytes), + "sample_rate": int(sample_rate), + "prepare_spec_build_ms": prepare_spec_build_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_executor_queue_ms": prepare_executor_queue_ms, + "prepare_executor_run_ms": prepare_executor_run_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile_wall_ms": prepare_profile_wall_ms, + "prepare_other_ms": prepare_other_ms, + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "api_after_prepare_ms": api_after_prepare_ms, + "api_wait_result_ms": api_wait_result_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "request_total_ms": request_total_ms, + "request_other_ms": request_other_ms, + } + result.update({key: value for key, value in worker_profile.items()}) + return result + + @staticmethod + def _format_ms_header(value: Any) -> str: + return f"{float(value):.3f}" + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + prepare_profile = dict(profile.get("prepare_profile", {})) + headers = { + "X-Request-Id": request_id, + "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), + "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), + "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), + "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), + "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), + "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), + "X-Prepare-Executor-Queue-Ms": self._format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), + "X-Prepare-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), + "X-Prepare-Executor-Run-Ms": self._format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), + "X-Prepare-Profile-Total-Ms": self._format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), + "X-Prepare-Profile-Wall-Ms": self._format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), + "X-Prepare-Other-Ms": self._format_ms_header(profile.get("prepare_other_ms", 0.0)), + "X-Api-After-Prepare-Ms": self._format_ms_header(profile.get("api_after_prepare_ms", 0.0)), + "X-Prefill-Ms": self._format_ms_header(profile.get("prefill_ms", 0.0)), + "X-Merge-Ms": self._format_ms_header(profile.get("merge_ms", 0.0)), + "X-Decode-Ms": self._format_ms_header(profile.get("decode_ms", 0.0)), + "X-Finalize-Wait-Ms": self._format_ms_header(profile.get("finalize_wait_ms", 0.0)), + "X-Synth-Ms": self._format_ms_header(profile.get("synth_ms", 0.0)), + "X-Worker-Residual-Ms": self._format_ms_header(profile.get("worker_residual_ms", 0.0)), + "X-Worker-Other-Ms": self._format_ms_header(profile.get("worker_other_ms", 0.0)), + "X-Pack-Ms": self._format_ms_header(profile.get("pack_ms", 0.0)), + "X-Worker-Total-Ms": self._format_ms_header(profile.get("worker_total_ms", 0.0)), + "X-Api-Wait-Result-Ms": self._format_ms_header(profile.get("api_wait_result_ms", 0.0)), + "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), + "X-Sample-Rate": str(int(sample_rate)), + "X-Response-Overhead-Ms": self._format_ms_header(profile.get("response_overhead_ms", 0.0)), + "X-Request-Other-Ms": self._format_ms_header(profile.get("request_other_ms", 0.0)), + "X-Request-Total-Ms": self._format_ms_header(profile.get("request_total_ms", 0.0)), + } + headers.update( + { + "X-Prepare-Prompt-Text-Ms": self._format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), + "X-Prepare-Target-Text-Ms": self._format_ms_header(prepare_profile.get("text_features_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), + "X-Prepare-Prompt-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), + "X-Prepare-Target-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), + "X-Prepare-Prompt-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), + "X-Prepare-Target-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), + "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), + "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-CPU-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), + "X-Prepare-Ref-Spec-Ms": self._format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), + "X-Prepare-Ref-Spec-Wait-Ms": self._format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), + "X-Prepare-Ref-Bundle-Ms": self._format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), + "X-Prepare-Tensorize-Ms": self._format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), + } + ) + return headers + + def _build_scheduler_debug_request_profile( + self, + *, + state: T2SRequestState, + item: T2SFinishedItem, + batch_request_count: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + batch_request_total_ms: float, + ) -> Dict[str, Any]: + prepare_profile = dict(state.prepare_profile) + prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) + return { + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(batch_request_count), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(batch_request_total_ms), + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), + "prepare_profile": prepare_profile, + "decode_steps": int(item.finish_idx), + "finish_idx": int(item.finish_idx), + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_reason": item.finish_reason, + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + + @staticmethod + def _build_scheduler_debug_batch_profile( + *, + request_count: int, + max_steps: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + request_total_ms: float, + finished_items: Sequence[T2SFinishedItem], + ) -> Dict[str, Any]: + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for item in finished_items: + finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 + total_semantic_len += int(item.semantic_tokens.shape[0]) + return { + "request_count": int(request_count), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + } + + def _normalize_lang(self, value: str | None) -> str | None: + if value in [None, ""]: + return value + return str(value).lower() + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + totals: Dict[str, float] = {} + for item in items: + for key, value in item.items(): + if isinstance(value, (int, float)): + totals[key] = totals.get(key, 0.0) + float(value) + return totals + + def _apply_default_reference(self, req: dict) -> dict: + normalized = dict(req) + default_ref = self.reference_registry.get_default() + if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: + normalized["ref_audio_path"] = default_ref.ref_audio_path + if "text_lang" in normalized: + normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) + if "prompt_lang" in normalized: + normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) + return normalized + + def check_params(self, req: dict) -> Optional[str]: + text = req.get("text", "") + text_lang = req.get("text_lang", "") + ref_audio_path = req.get("ref_audio_path", "") + media_type = req.get("media_type", "wav") + prompt_lang = req.get("prompt_lang", "") + text_split_method = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return "ref_audio_path is required" + if text in [None, ""]: + return "text is required" + if text_lang in [None, ""]: + return "text_lang is required" + if text_lang.lower() not in self.tts.configs.languages: + return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" + if prompt_lang in [None, ""]: + return "prompt_lang is required" + if prompt_lang.lower() not in self.tts.configs.languages: + return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" + if media_type not in ["wav", "raw", "ogg", "aac"]: + return f"media_type: {media_type} is not supported" + if text_split_method not in self.cut_method_names: + return f"text_split_method:{text_split_method} is not supported" + return None + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return { + "request_id": None, + "text": None, + "text_lang": None, + "ref_audio_path": None, + "aux_ref_audio_paths": None, + "prompt_text": "", + "prompt_lang": None, + "top_k": 15, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "return_fragment": False, + "fixed_length_chunk": False, + "response_streaming": False, + "parallel_infer": False, + "repetition_penalty": 1.35, + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + "early_stop_num": -1, + "ready_step": 0, + "timeout_sec": None, + } + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + if isinstance(payload, NormalizedEngineRequest): + normalized_payload = payload.to_payload() + else: + normalized_payload = self._base_request_defaults() + normalized_payload.update(dict(payload)) + if request_id not in [None, ""]: + normalized_payload["request_id"] = str(request_id) + elif normalized_payload.get("request_id") in [None, ""]: + raise ValueError("request_id is required after normalization") + normalized_payload = self._apply_default_reference(normalized_payload) + if normalize_streaming: + normalized_payload = self._normalize_streaming_mode(normalized_payload) + error = self.check_params(normalized_payload) + if error is not None: + raise ValueError(f"{error_prefix}{error}") + timeout_sec = normalized_payload.get("timeout_sec") + if timeout_sec in [None, ""]: + parsed_timeout = None + else: + parsed_timeout = float(timeout_sec) + aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") + if aux_ref_audio_paths in [None, "", []]: + normalized_aux_ref_audio_paths = None + else: + normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths] + return NormalizedEngineRequest( + request_id=str(normalized_payload["request_id"]), + text=str(normalized_payload["text"]), + text_lang=str(normalized_payload["text_lang"]), + ref_audio_path=str(normalized_payload["ref_audio_path"]), + prompt_lang=str(normalized_payload["prompt_lang"]), + prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), + aux_ref_audio_paths=normalized_aux_ref_audio_paths, + top_k=int(normalized_payload["top_k"]), + top_p=float(normalized_payload["top_p"]), + temperature=float(normalized_payload["temperature"]), + repetition_penalty=float(normalized_payload["repetition_penalty"]), + early_stop_num=int(normalized_payload.get("early_stop_num", -1)), + ready_step=int(normalized_payload.get("ready_step", 0)), + text_split_method=str(normalized_payload["text_split_method"]), + batch_size=int(normalized_payload["batch_size"]), + batch_threshold=float(normalized_payload["batch_threshold"]), + split_bucket=bool(normalized_payload["split_bucket"]), + speed_factor=float(normalized_payload["speed_factor"]), + fragment_interval=float(normalized_payload["fragment_interval"]), + seed=int(normalized_payload["seed"]), + media_type=str(normalized_payload["media_type"]), + streaming_mode=normalized_payload["streaming_mode"], + return_fragment=bool(normalized_payload.get("return_fragment", False)), + fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), + response_streaming=bool(normalized_payload.get("response_streaming", False)), + parallel_infer=bool(normalized_payload["parallel_infer"]), + sample_steps=int(normalized_payload["sample_steps"]), + super_sampling=bool(normalized_payload["super_sampling"]), + overlap_length=int(normalized_payload["overlap_length"]), + min_chunk_length=int(normalized_payload["min_chunk_length"]), + timeout_sec=parsed_timeout, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + normalized = dict(req) + streaming_mode = normalized.get("streaming_mode", False) + return_fragment = normalized.get("return_fragment", False) + if streaming_mode is False: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 0: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 1 or streaming_mode is True: + normalized["streaming_mode"] = False + normalized["return_fragment"] = True + normalized["fixed_length_chunk"] = False + elif streaming_mode == 2: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 3: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = True + else: + raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") + normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) + return normalized + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return aux_ref_audio_paths not in [None, [], ()] + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + if normalized.response_streaming: + if normalized.return_fragment or normalized.fixed_length_chunk: + return "legacy_direct_fragment", "fragment_streaming_mode" + return "legacy_direct_streaming", "streaming_mode" + if self._is_aux_ref_enabled(normalized.aux_ref_audio_paths): + return "legacy_direct_aux_ref", "aux_ref_audio_paths" + if normalized.super_sampling: + return "legacy_direct_super_sampling", "super_sampling" + if normalized.prompt_text in [None, ""]: + return "legacy_direct_missing_prompt", "missing_prompt_text" + return "scheduler_v1_direct", None + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + payload = normalized.to_payload() + media_type = normalized.media_type + request_id = normalized.request_id + request_start = time.perf_counter() + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + try: + with self.direct_tts_lock: + tts_generator = self.tts.run(payload) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk: + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "sample_rate": int(sr), + }, + ) + if first_chunk and media_type == "wav": + header = wave_header_chunk(sample_rate=sr) + chunk_count += 1 + stream_total_bytes += len(header) + yield header + current_media_type = "raw" + first_chunk = False + elif first_chunk: + first_chunk = False + packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + yield packed_chunk + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + self._complete_request_state( + request_id, + dict( + self._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ), + streaming_completed=True, + ), + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + if isinstance(req, NormalizedEngineRequest): + normalized = req + else: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + ) + backend, _ = self._select_direct_backend(normalized) + return backend == "scheduler_v1_direct" + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized + return self.tts.text_preprocessor.pre_seg_text( + str(payload["text"]), + str(payload["text_lang"]), + str(payload.get("text_split_method", "cut5")), + ) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + payload = normalized.to_payload() + payload["request_id"] = request_id + payload["text"] = text + payload["streaming_mode"] = False + payload["return_fragment"] = False + payload["fixed_length_chunk"] = False + payload["response_streaming"] = False + return self._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + self._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + segment_specs: List[SchedulerRequestSpec] = [] + for segment_index, segment_text in enumerate(segment_texts): + segment_request = self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + segment_specs.append(self.build_scheduler_submit_spec(segment_request)) + + prepared_items = await asyncio.gather( + *[ + self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) + for spec in segment_specs + ] + ) + prepare_profiles: List[Dict[str, Any]] = [] + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + self._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, + ) + for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profiles.append( + { + "request_id": spec.request_id, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": dict(state.prepare_profile), + } + ) + done_future = loop.create_future() + done_futures.append(done_future) + await self._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, + ) + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) + for profile_item, job in zip(prepare_profiles, jobs): + profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) + profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) + self._merge_request_state_profile( + request_id, + { + "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), + "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), + "prepare_aggregate": self._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + }, + ) + + sample_rate: int | None = None + audio_parts: List[np.ndarray] = [] + worker_profiles: List[Dict[str, Any]] = [] + fragment_interval = float(normalized.fragment_interval) + silence_chunk: Optional[np.ndarray] = None + for job in jobs: + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + if sample_rate is None: + sample_rate = int(job.sample_rate) + silence_samples = int(fragment_interval * float(sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + elif int(job.sample_rate) != sample_rate: + raise RuntimeError("segment sample rate mismatch") + audio_parts.append(job.audio_data) + if silence_chunk is not None: + audio_parts.append(silence_chunk.copy()) + worker_profiles.append(dict(job.result)) + if sample_rate is None or not audio_parts: + raise RuntimeError("direct scheduler backend produced no audio") + self._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + merged_audio = np.concatenate(audio_parts, axis=0) + pack_start = time.perf_counter() + audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + direct_profile = self._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=len(audio_bytes), + sample_rate=int(sample_rate), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=0.0, + ) + self._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=False), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=audio_bytes, + request_id=request_id, + ) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + normalized_payload = normalized.to_payload() + request_id = normalized.request_id + media_type = normalized.media_type + request_start = time.perf_counter() + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + with self.direct_tts_lock: + tts_generator = self.tts.run(normalized_payload) + try: + sr, audio_data = next(tts_generator) + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + self._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + pack_start = time.perf_counter() + packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + self._complete_request_state( + request_id, + dict( + self._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + sample_rate=int(sr), + audio_bytes=len(packed_audio), + pack_ms=pack_ms, + ), + streaming_completed=False, + ), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=packed_audio, + request_id=request_id, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + if normalized.response_streaming: + return DirectTTSExecution( + media_type=normalized.media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=normalized.request_id, + ) + return await asyncio.to_thread( + self._run_legacy_direct_tts_blocking, + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self._select_direct_backend(normalized) + self._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + deadline_ts=( + time.perf_counter() + float(normalized.timeout_sec) + if normalized.timeout_sec is not None + else None + ), + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend == "scheduler_v1_direct": + try: + return await self._run_direct_tts_via_scheduler(normalized) + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + return await self._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self._select_direct_backend(normalized) + if not self._has_active_request(request_id): + self._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend != "scheduler_v1_direct": + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ) + + def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, payload in enumerate(request_items): + normalized = self._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", + ) + specs.append(normalized.to_scheduler_spec()) + return specs + + def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + normalized = self._normalize_engine_request( + payload, + request_id=( + payload.request_id + if isinstance(payload, NormalizedEngineRequest) + else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") + ), + ) + return normalized.to_scheduler_spec() + + @staticmethod + def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + @staticmethod + def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + request_start = time.perf_counter() + set_scheduler_seed(seed) + specs = self.build_scheduler_request_specs(request_items) + request_ids = [spec.request_id for spec in specs] + for spec in specs: + self._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_debug", + backend="scheduler_debug", + media_type="wav", + response_streaming=False, + meta={ + "text_len": len(spec.text), + "prompt_text_len": len(spec.prompt_text), + "text_lang": spec.text_lang, + "prompt_lang": spec.prompt_lang, + "ref_audio_path": str(spec.ref_audio_path), + "ready_step": int(spec.ready_step), + }, + ) + self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) + self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) + prepare_started_at = time.perf_counter() + try: + states = await self.scheduler_worker.prepare_states_batch_async(specs) + except Exception as exc: + for request_id in request_ids: + self._fail_request_state(request_id, str(exc)) + raise + prepare_finished_at = time.perf_counter() + prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) + for state in states: + self._update_request_state( + state.request_id, + EngineStatus.ACTIVE_DECODE, + { + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + }, + ) + decode_started_at = time.perf_counter() + try: + finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) + except Exception as exc: + for request_id in request_ids: + self._fail_request_state(request_id, str(exc)) + raise + decode_finished_at = time.perf_counter() + decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) + request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) + finished_map = {item.request_id: item for item in finished} + request_profiles: List[Dict[str, Any]] = [] + for state in states: + item = finished_map.get(state.request_id) + if item is None: + self._fail_request_state(state.request_id, "scheduler_debug finished without result") + continue + request_profile = self._build_scheduler_debug_request_profile( + state=state, + item=item, + batch_request_count=len(states), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + batch_request_total_ms=request_total_ms, + ) + request_profiles.append( + { + "request_id": state.request_id, + "profile": dict(request_profile), + } + ) + self._complete_request_state( + state.request_id, + dict(request_profile), + ) + return SchedulerDebugExecution( + payload={ + "message": "success", + "request_count": len(states), + "max_steps": int(max_steps), + "batch_profile": self._build_scheduler_debug_batch_profile( + request_count=len(states), + max_steps=int(max_steps), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + request_total_ms=request_total_ms, + finished_items=finished, + ), + "requests": self.summarize_scheduler_states(states), + "finished": self.summarize_scheduler_finished(finished), + "request_profiles": request_profiles, + "request_traces": self._collect_request_summaries(request_ids), + } + ) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + request_start = time.perf_counter() + prepare_start = request_start + normalized = self._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), + ) + spec = self.build_scheduler_submit_spec(normalized) + deadline_ts = None + timeout_sec = normalized.timeout_sec + if timeout_sec is not None: + try: + deadline_ts = request_start + float(timeout_sec) + except Exception: + deadline_ts = None + self._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_submit", + backend="scheduler_v1", + media_type=normalized.media_type, + response_streaming=False, + deadline_ts=deadline_ts, + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=spec_ready_at, + engine_request_id=spec.request_id, + ) + except Exception as exc: + self._fail_request_state(spec.request_id, str(exc)) + raise + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) + prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) + prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile = dict(state.prepare_profile) + prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + self._update_request_state( + spec.request_id, + EngineStatus.READY_FOR_PREFILL, + { + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": prepare_profile, + }, + ) + api_after_prepare_start = time.perf_counter() + loop = asyncio.get_running_loop() + done_future = loop.create_future() + await self._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=spec.request_id, + timeout_sec=normalized.timeout_sec, + ) + api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) + try: + job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + except Exception as exc: + self._fail_request_state(spec.request_id, str(exc)) + raise + wait_return_at = time.perf_counter() + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + self._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") + raise RuntimeError(f"{job.request_id} finished without audio result") + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + submit_profile = self._build_scheduler_submit_profile( + backend="scheduler_v1", + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=len(audio_data), + sample_rate=int(job.sample_rate), + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=dict(job.result or {}), + ) + headers = self._build_scheduler_submit_headers( + request_id=job.request_id, + media_type=job.media_type, + sample_rate=int(job.sample_rate), + profile=submit_profile, + ) + self._merge_request_state_profile( + spec.request_id, + dict(submit_profile, response_headers_emitted=True), + ) + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py new file mode 100644 index 00000000..5c3bd7a5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import subprocess +import threading +import wave +from io import BytesIO + +import numpy as np +import soundfile as sf +import torch + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except (RuntimeError, ValueError): + handle_pack_ogg() + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", + "-ar", + str(rate), + "-ac", + "1", + "-i", + "pipe:0", + "-c:a", + "aac", + "-b:a", + "192k", + "-vn", + "-f", + "adts", + "pipe:1", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py new file mode 100644 index 00000000..536efbc5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def request_registry(self): + return self.owner.request_registry + + @property + def engine_prepare_queue_owner(self): + return self.owner.engine_prepare_queue_owner + + @property + def engine_finalize_queue_owner(self): + return self.owner.engine_finalize_queue_owner + + @property + def engine_dispatch_queue_owner(self): + return self.owner.engine_dispatch_queue_owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_job_registry(self): + return self.owner.engine_job_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_stage_coordinator(self): + return self.owner.engine_stage_coordinator + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.request_registry.register( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.request_registry.update(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.merge_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.complete(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.request_registry.fail(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.request_registry.snapshot() + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self._enqueue_worker_finished_for_finalize(tasks) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_stage_coordinator.decode_runtime_owner.enqueue_pending_job(job) + self._notify_engine_arbiter() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_stage_coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.engine_stage_coordinator.has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state + + def _run_engine_prepare_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.engine_stage_coordinator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py new file mode 100644 index 00000000..45178b1f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import os +import threading +from typing import Any + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineArbiterConfig, + EngineDecodeRuntimeOwner, + EnginePolicyArbiterController, + EnginePolicyConfig, + EngineRequestRegistry, + EngineTaskQueueOwner, + ModelRegistry, + ReferenceRegistry, + RuntimeStateCallbacks, + SchedulerJobRegistry, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineCompositionBuilder: + def __init__(self, owner: Any) -> None: + self.owner = owner + + def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None: + self._init_registries_and_locks() + self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) + self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms) + self._init_runtime_owners() + self._init_stage_coordinator() + self._init_arbiter() + self._init_facades() + self._start_arbiter_thread() + + def _init_registries_and_locks(self) -> None: + owner = self.owner + owner.reference_registry = ReferenceRegistry() + owner.model_registry = ModelRegistry( + t2s_weights_path=str(owner.tts.configs.t2s_weights_path), + vits_weights_path=str(owner.tts.configs.vits_weights_path), + ) + owner.request_registry = EngineRequestRegistry( + recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + ) + owner.engine_job_registry = SchedulerJobRegistry(threading.Lock()) + owner.direct_tts_lock = threading.RLock() + owner.management_lock = threading.RLock() + owner.engine_dispatch_last_snapshot = {} + + def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None: + owner = self.owner + owner.scheduler_worker = UnifiedSchedulerWorker( + owner.tts, + max_steps=max_steps, + micro_batch_wait_ms=micro_batch_wait_ms, + runtime_callbacks=RuntimeStateCallbacks( + update=owner._update_request_state, + complete=owner._complete_request_state, + fail=owner._fail_request_state, + decode_runtime_update=owner._update_engine_decode_runtime_state, + ), + external_finalize_submit=owner._enqueue_worker_finished_for_finalize, + ) + + def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None: + owner = self.owner + worker_capacity_limits = owner.scheduler_worker.get_capacity_limits() + prepare_max_inflight = int(owner.scheduler_worker.get_prepare_max_inflight()) + owner.engine_policy_config = EnginePolicyConfig( + enabled=owner._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), + poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_backlog_soft_max=max( + 0, + owner._env_int( + "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", + int(worker_capacity_limits["decode_backlog_max"]), + ), + ), + finalize_pending_soft_max=max( + 0, + owner._env_int( + "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", + int(worker_capacity_limits["finalize_pending_max"]), + ), + ), + prepare_inflight_soft_max=max( + 0, + owner._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), + ), + active_decode_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), + ready_for_prefill_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), + active_request_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), + ) + owner.engine_arbiter_config = EngineArbiterConfig( + poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_burst=max(1, owner._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), + prepare_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), + finalize_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), + ) + + def _init_runtime_owners(self) -> None: + owner = self.owner + owner.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( + get_decode_runtime_counters=owner.scheduler_worker.get_decode_runtime_counters, + get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, + ) + owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") + + def _init_stage_coordinator(self) -> None: + owner = self.owner + owner.engine_stage_coordinator = EngineStageCoordinator( + tts=owner.tts, + scheduler_worker=owner.scheduler_worker, + prepare_queue_owner=owner.engine_prepare_queue_owner, + finalize_queue_owner=owner.engine_finalize_queue_owner, + dispatch_queue_owner=owner.engine_dispatch_queue_owner, + decode_runtime_owner=owner.engine_decode_runtime_owner, + update_request_state=owner._update_request_state, + merge_request_state_profile=owner._merge_request_state_profile, + fail_request_state=owner._fail_request_state, + get_engine_job=owner._get_engine_job, + register_engine_job=owner._register_engine_job, + fail_engine_jobs=owner._fail_engine_jobs, + complete_engine_job=owner._complete_engine_job, + add_engine_prefill_time=owner._add_engine_prefill_time, + add_engine_merge_time=owner._add_engine_merge_time, + add_engine_decode_time=owner._add_engine_decode_time, + enqueue_engine_finished_items=owner._enqueue_engine_finished_items, + snapshot_engine_dispatch_state=owner._snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=owner._snapshot_engine_decode_runtime_state, + ) + + def _init_arbiter(self) -> None: + owner = self.owner + owner.engine_policy_arbiter = EnginePolicyArbiterController( + policy_config=owner.engine_policy_config, + arbiter_config=owner.engine_arbiter_config, + snapshot_request_registry=owner._snapshot_request_registry, + get_worker_state=owner.get_scheduler_state, + snapshot_prepare_state=owner._snapshot_engine_prepare_state, + snapshot_finalize_state=owner._snapshot_engine_finalize_state, + snapshot_dispatch_state=owner._snapshot_engine_dispatch_state, + snapshot_decode_runtime_state=owner._snapshot_engine_decode_runtime_state, + snapshot_job_registry=owner._snapshot_engine_job_registry, + peek_queue_age_ms=owner.engine_stage_coordinator.peek_queue_age_ms, + merge_request_state_profile=owner._merge_request_state_profile, + ) + owner.engine_stage_coordinator.bind_arbiter( + notify_arbiter=owner._notify_engine_arbiter, + select_stage=owner._select_engine_stage, + mark_arbiter_tick=lambda stage, reason, policy_allowed: owner._mark_arbiter_tick( + stage=stage, + reason=reason, + policy_allowed=policy_allowed, + ), + wait_arbiter=owner.engine_policy_arbiter.wait, + ) + + def _init_facades(self) -> None: + owner = self.owner + owner.bridge_facade = EngineBridgeFacade(owner) + owner.api_facade = EngineApiFacade(owner) + owner.runtime_facade = EngineRuntimeFacade(owner) + + def _start_arbiter_thread(self) -> None: + owner = self.owner + owner.engine_arbiter_thread = threading.Thread( + target=owner._run_engine_arbiter_loop, + name="unified-engine-arbiter", + daemon=True, + ) + owner.engine_arbiter_thread.start() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py new file mode 100644 index 00000000..3a124f4e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py @@ -0,0 +1,1150 @@ +from __future__ import annotations + +import asyncio +import os +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Deque, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + ready_step=self.ready_step, + ) + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] + + +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + requested = set(request_ids) + results: List[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + engine_request_id: str | None = None + + +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage == "prepare": + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + return "prepare", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + return "prepare", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int( + active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) + ), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + queue_wait_ms: float = 0.0 + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None + + diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py new file mode 100644 index 00000000..7dbbd5bd --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade + + +class EngineBridgeDelegates: + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.bridge_facade._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._merge_request_state_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_prepare_state() + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_finalize_state() + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_dispatch_state() + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._register_engine_job(job) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._get_engine_job(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._pop_engine_job(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_job_registry() + + def _is_engine_drained(self) -> bool: + return self.bridge_facade._is_engine_drained() + + def _record_engine_job_done(self, request_id: str) -> None: + self.bridge_facade._record_engine_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + self.bridge_facade._fail_engine_jobs(request_ids, error) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + self.bridge_facade._enqueue_engine_finished_items(items) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineBridgeFacade._summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.bridge_facade._refresh_engine_decode_runtime_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + self.bridge_facade._update_engine_decode_runtime_state(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_runtime_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_arbiter_state() + + def _notify_engine_arbiter(self) -> None: + self.bridge_facade._notify_engine_arbiter() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._enqueue_engine_decode_pending_job(job) + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.bridge_facade._peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.bridge_facade._engine_has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.bridge_facade._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.bridge_facade._take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.bridge_facade._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + return self.bridge_facade._select_engine_stage() + + def _run_engine_prepare_once(self) -> bool: + return self.bridge_facade._run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.bridge_facade._run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.bridge_facade._run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.bridge_facade._run_engine_arbiter_loop() + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.bridge_facade._fail_request_state(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_request_registry() + + +class EngineApiDelegates: + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.api_facade._collect_request_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.api_facade._has_active_request(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + return EngineApiFacade._build_request_meta(payload) + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + return EngineApiFacade._sum_profile_field(items, key) + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + + def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_direct_scheduler_profile(**kwargs) + + def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_legacy_direct_profile(**kwargs) + + def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_submit_profile(**kwargs) + + @staticmethod + def _format_ms_header(value: Any) -> str: + return EngineApiFacade._format_ms_header(value) + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + return self.api_facade._build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, + ) + + def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_debug_request_profile(**kwargs) + + @staticmethod + def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: + return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) + + def _normalize_lang(self, value: str | None) -> str | None: + return self.api_facade._normalize_lang(value) + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + return EngineApiFacade._aggregate_numeric_dicts(items) + + def _apply_default_reference(self, req: dict) -> dict: + return self.api_facade._apply_default_reference(req) + + def check_params(self, req: dict) -> Optional[str]: + return self.api_facade.check_params(req) + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return EngineApiFacade._base_request_defaults() + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + return self.api_facade._normalize_engine_request( + payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + return EngineApiFacade._normalize_streaming_mode(req) + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return self.api_facade._select_direct_backend(normalized) + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + return self.api_facade._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + return self.api_facade._should_use_scheduler_backend_for_direct(req) + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + return self.api_facade._segment_direct_text(normalized) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_scheduler(normalized) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return self.api_facade._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.api_facade.run_direct_tts_async(req) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.api_facade.run_direct_tts(req) + + def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + return self.api_facade.build_scheduler_request_specs(request_items) + + def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + return self.api_facade.build_scheduler_submit_spec(payload) + + @staticmethod + def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return EngineApiFacade.summarize_scheduler_states(states) + + @staticmethod + def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return EngineApiFacade.summarize_scheduler_finished(items) + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.api_facade.run_scheduler_submit(payload) + + +class EngineRuntimeDelegates: + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + return EngineRuntimeFacade._safe_component_snapshot(component) + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_summary(request_registry, worker_state) + + def get_scheduler_state(self) -> dict: + return self.runtime_facade.get_scheduler_state() + + def get_runtime_state(self) -> dict: + return self.runtime_facade.get_runtime_state() + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + return self.runtime_facade.set_refer_audio(refer_audio_path) + + def set_gpt_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_gpt_weights(weights_path) + + def set_sovits_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_sovits_weights(weights_path) + + def handle_control(self, command: str) -> None: + self.runtime_facade.handle_control(command) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py new file mode 100644 index 00000000..70212a4d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import os +import signal +import sys +from typing import Any, Dict, Optional + + +class EngineRuntimeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def tts(self): + return self.owner.tts + + @property + def reference_registry(self): + return self.owner.reference_registry + + @property + def model_registry(self): + return self.owner.model_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + @property + def management_lock(self): + return self.owner.management_lock + + @property + def direct_tts_lock(self): + return self.owner.direct_tts_lock + + @property + def control_callbacks(self): + return self.owner.control_callbacks + + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + if component is None or not hasattr(component, "snapshot"): + return None + try: + return dict(component.snapshot()) + except Exception: + return None + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self._build_stage_counters(request_registry, worker_state) + bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) + ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) + text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) + + return { + **counters, + "engine_drained": bool(self.owner._is_engine_drained()), + "admission_config": { + "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), + "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), + }, + "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), + "engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(), + "engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(), + "engine_job_registry": self.owner._snapshot_engine_job_registry(), + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": self.owner._snapshot_engine_prepare_state(), + "engine_finalize_state": self.owner._snapshot_engine_finalize_state(), + "engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(), + "active_batch": dict(worker_state.get("active_batch") or {}), + "prepare_state": dict(worker_state.get("prepare_state") or {}), + "bert_batch_worker_state": bert_worker_state, + "ref_semantic_worker_state": ref_semantic_worker_state, + "text_preprocessor_state": text_preprocessor_state, + } + + def get_scheduler_state(self) -> dict: + return self.scheduler_worker.snapshot() + + def get_runtime_state(self) -> dict: + model_state = self.model_registry.snapshot() + default_ref = self.reference_registry.get_default() + scheduler_state = self.get_scheduler_state() + request_registry = self.owner._snapshot_request_registry() + engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) + engine_arbiter_state = self.owner._snapshot_engine_arbiter_state() + engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state() + engine_job_registry = self.owner._snapshot_engine_job_registry() + engine_prepare_state = self.owner._snapshot_engine_prepare_state() + engine_finalize_state = self.owner._snapshot_engine_finalize_state() + engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state() + engine_drained = self.owner._is_engine_drained() + return { + "message": "success", + "default_reference": { + "ref_audio_path": default_ref.ref_audio_path, + "updated_at": default_ref.updated_at, + }, + "model_registry": { + "generation": model_state.generation, + "t2s_generation": model_state.t2s_generation, + "vits_generation": model_state.vits_generation, + "t2s_weights_path": model_state.t2s_weights_path, + "vits_weights_path": model_state.vits_weights_path, + "updated_at": model_state.updated_at, + }, + "worker_state": scheduler_state, + "engine_policy": engine_policy, + "engine_arbiter_state": engine_arbiter_state, + "engine_decode_runtime_state": engine_decode_runtime_state, + "engine_job_registry": engine_job_registry, + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": engine_prepare_state, + "engine_finalize_state": engine_finalize_state, + "engine_dispatcher_state": engine_dispatcher_state, + "engine_drained": bool(engine_drained), + "request_registry": request_registry, + "stage_summary": self._build_stage_summary(request_registry, scheduler_state), + } + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): + raise TimeoutError("scheduler worker did not drain before model reload") + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + if refer_audio_path in [None, ""]: + state = self.reference_registry.clear() + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + if not os.path.exists(str(refer_audio_path)): + raise FileNotFoundError(f"{refer_audio_path} not exists") + with self.management_lock: + with self.direct_tts_lock: + self.tts.set_ref_audio(str(refer_audio_path)) + state = self.reference_registry.set_default(str(refer_audio_path)) + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + + def set_gpt_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("gpt weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_t2s_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_t2s_reload(str(weights_path)) + return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} + + def set_sovits_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("sovits weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_vits_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_vits_reload(str(weights_path)) + return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} + + def handle_control(self, command: str) -> None: + if command == "restart": + if self.control_callbacks.restart is None: + os.execl(sys.executable, sys.executable, *sys.argv) + self.control_callbacks.restart() + return + if command == "exit": + if self.control_callbacks.exit is None: + os.kill(os.getpid(), signal.SIGTERM) + return + self.control_callbacks.exit() + return + raise ValueError(f"unsupported command: {command}") diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py new file mode 100644 index 00000000..65f0befe --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -0,0 +1,420 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineDispatchTask, + EngineGpuPrepareTask, + EngineStatus, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageCoordinator: + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.tts = tts + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.update_request_state = update_request_state + self.merge_request_state_profile = merge_request_state_profile + self.fail_request_state = fail_request_state + self.get_engine_job = get_engine_job + self.register_engine_job = register_engine_job + self.fail_engine_jobs = fail_engine_jobs + self.complete_engine_job = complete_engine_job + self.add_engine_prefill_time = add_engine_prefill_time + self.add_engine_merge_time = add_engine_merge_time + self.add_engine_decode_time = add_engine_decode_time + self.enqueue_engine_finished_items = enqueue_engine_finished_items + self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._notify_arbiter: Callable[[], None] | None = None + self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None + self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None + self._wait_arbiter: Callable[[], None] | None = None + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self._notify_arbiter = notify_arbiter + self._select_stage = select_stage + self._mark_arbiter_tick = mark_arbiter_tick + self._wait_arbiter = wait_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + state, prepare_exec_started_at, prepare_exec_finished_at = await done_future + return state, prepare_exec_started_at, prepare_exec_finished_at + + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "finalize": + return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + if queue_name == "decode_runtime_pending": + return self.decode_runtime_owner.pending_age_ms() + return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( + "active_request_count", 0 + ) > 0: + return True + if self.prepare_queue_owner.has_items(): + return True + if self.finalize_queue_owner.has_items(): + return True + return self.dispatch_queue_owner.has_items() + + def run_engine_prepare_once(self) -> bool: + task = self.prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False)) + + def run_engine_arbiter_loop(self) -> None: + if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: + raise RuntimeError("arbiter callbacks are not bound") + while True: + if not self.has_pending_work(): + self._mark_arbiter_tick("idle", "no_pending_work", True) + self._wait_arbiter() + continue + stage, reason, policy_snapshot, worker_state = self._select_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self.run_engine_prepare_once() + elif stage == "finalize": + executed = self.run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self.run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self.run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) + self._wait_arbiter() + continue + self._mark_arbiter_tick(stage, reason, policy_allowed) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py new file mode 100644 index 00000000..04d9090f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -0,0 +1,1510 @@ +from __future__ import annotations + +import asyncio +import os +import threading +import time +from collections import deque +from typing import Any, Callable, Deque, Dict, List, Optional + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState, decode_one_step, merge_active_batches, run_prefill_active_batch, run_scheduler_continuous +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry, SchedulerPendingJob + + +class WorkerPrepareExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) + + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) + + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: + with self.condition: + return { + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), + } + + def is_idle(self) -> bool: + with self.condition: + return self.inflight <= 0 and not self.pending_tasks + + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results + + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: + while True: + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: + with self.condition: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) + + +class UnifiedSchedulerWorker: + def __init__( + self, + tts: TTS, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + runtime_callbacks: RuntimeStateCallbacks | None = None, + external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.condition = threading.Condition() + self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) + self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) + self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) + self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) + self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) + self.finalize_executor = WorkerFinalizeExecutor( + tts, + on_state_change=self._notify_worker_state_change, + external_submit=external_finalize_submit, + ) + self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) + self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) + self.engine_decode_control_enabled = ( + str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} + ) + self.job_registry = SchedulerJobRegistry(self.condition) + self.worker_thread: threading.Thread | None = None + if not self.engine_decode_control_enabled: + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [] + if external_finalize_submit is None: + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_executor.get_worker_count()) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def _notify_worker_state_change(self) -> None: + with self.condition: + self.condition.notify_all() + + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + last_snapshot: Dict[str, int] = {} + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + last_snapshot = snapshot + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() + + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) + + def _run_loop(self) -> None: + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) + +