From a3a5aad15707ac984d0dae58794f18c99eab77c1 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 20:49:41 +0800 Subject: [PATCH] Add unified engine components for TTS processing and state management Introduce new modules including unified_engine_component_models, unified_engine_component_policy, unified_engine_component_registry, unified_engine_component_runtime, unified_engine_worker_completion, and unified_engine_worker_decode. These additions enhance the TTS framework by providing structured models for request handling, engine policies, and worker execution, significantly improving the architecture and maintainability of the system. The new components support asynchronous operations and optimize overall performance through better state management and processing capabilities. --- .../unified_engine_component_models.py | 120 ++ .../unified_engine_component_policy.py | 335 ++++ .../unified_engine_component_registry.py | 381 +++++ .../unified_engine_component_runtime.py | 334 ++++ .../unified_engine_components.py | 1213 +------------- .../TTS_infer_pack/unified_engine_worker.py | 1471 +---------------- .../unified_engine_worker_completion.py | 198 +++ .../unified_engine_worker_decode.py | 430 +++++ .../unified_engine_worker_execution.py | 164 ++ .../unified_engine_worker_finalize.py | 234 +++ .../unified_engine_worker_prepare.py | 71 + .../unified_engine_worker_runtime.py | 170 ++ .../unified_engine_worker_submit.py | 256 +++ 13 files changed, 2772 insertions(+), 2605 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py new file mode 100644 index 00000000..2c0cc9ac --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Generator, List, Optional + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + ready_step=self.ready_step, + ) + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py new file mode 100644 index 00000000..b6c5ca4d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus + + +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage == "prepare": + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + return "prepare", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + return "prepare", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py new file mode 100644 index 00000000..111ca500 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Deque, Dict, Optional, Sequence + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]: + requested = set(request_ids) + results: list[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + engine_request_id: str | None = None + + +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py new file mode 100644 index 00000000..db03a0c3 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, List, Optional, Sequence + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + queue_wait_ms: float = 0.0 + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py index 3a124f4e..ac1adac5 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py @@ -1,1150 +1,63 @@ -from __future__ import annotations - -import asyncio -import os -import threading -import time -import uuid -from collections import deque -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Deque, Dict, List, Optional, Sequence, Tuple, Union - -import numpy as np -import torch - -from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState - - -@dataclass -class RuntimeControlCallbacks: - restart: Callable[[], None] | None = None - exit: Callable[[], None] | None = None - - -@dataclass -class DefaultReferenceState: - ref_audio_path: str | None = None - updated_at: float = 0.0 - - -class ReferenceRegistry: - def __init__(self) -> None: - self._lock = threading.Lock() - self._state = DefaultReferenceState() - - def set_default(self, ref_audio_path: str) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) - return self._state - - def clear(self) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState() - return self._state - - def get_default(self) -> DefaultReferenceState: - with self._lock: - return DefaultReferenceState( - ref_audio_path=self._state.ref_audio_path, - updated_at=self._state.updated_at, - ) - - -@dataclass -class ModelRegistryState: - t2s_weights_path: str - vits_weights_path: str - generation: int = 0 - t2s_generation: int = 0 - vits_generation: int = 0 - updated_at: float = field(default_factory=time.time) - - -class ModelRegistry: - def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: - self._lock = threading.Lock() - self._state = ModelRegistryState( - t2s_weights_path=str(t2s_weights_path), - vits_weights_path=str(vits_weights_path), - ) - - def snapshot(self) -> ModelRegistryState: - with self._lock: - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.t2s_weights_path = str(weights_path) - self._state.generation += 1 - self._state.t2s_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.vits_weights_path = str(weights_path) - self._state.generation += 1 - self._state.vits_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - -@dataclass -class DirectTTSExecution: - media_type: str - streaming: bool - audio_generator: Optional[Generator[bytes, None, None]] = None - audio_bytes: Optional[bytes] = None - request_id: Optional[str] = None - - -@dataclass -class NormalizedEngineRequest: - request_id: str - text: str - text_lang: str - ref_audio_path: str - prompt_lang: str - prompt_text: str = "" - aux_ref_audio_paths: List[str] | None = None - top_k: int = 15 - top_p: float = 1.0 - temperature: float = 1.0 - repetition_penalty: float = 1.35 - early_stop_num: int = -1 - ready_step: int = 0 - text_split_method: str = "cut5" - batch_size: int = 1 - batch_threshold: float = 0.75 - split_bucket: bool = False - speed_factor: float = 1.0 - fragment_interval: float = 0.3 - seed: int = -1 - media_type: str = "wav" - streaming_mode: bool | int = False - return_fragment: bool = False - fixed_length_chunk: bool = False - response_streaming: bool = False - parallel_infer: bool = False - sample_steps: int = 32 - super_sampling: bool = False - overlap_length: int = 2 - min_chunk_length: int = 16 - timeout_sec: float | None = None - - def to_payload(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "text": self.text, - "text_lang": self.text_lang, - "ref_audio_path": self.ref_audio_path, - "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, - "prompt_text": self.prompt_text, - "prompt_lang": self.prompt_lang, - "top_k": self.top_k, - "top_p": self.top_p, - "temperature": self.temperature, - "text_split_method": self.text_split_method, - "batch_size": self.batch_size, - "batch_threshold": self.batch_threshold, - "speed_factor": self.speed_factor, - "split_bucket": self.split_bucket, - "fragment_interval": self.fragment_interval, - "seed": self.seed, - "media_type": self.media_type, - "streaming_mode": self.streaming_mode, - "return_fragment": self.return_fragment, - "fixed_length_chunk": self.fixed_length_chunk, - "response_streaming": self.response_streaming, - "parallel_infer": self.parallel_infer, - "repetition_penalty": self.repetition_penalty, - "sample_steps": self.sample_steps, - "super_sampling": self.super_sampling, - "overlap_length": self.overlap_length, - "min_chunk_length": self.min_chunk_length, - "early_stop_num": self.early_stop_num, - "ready_step": self.ready_step, - "timeout_sec": self.timeout_sec, - } - - def to_scheduler_spec(self) -> SchedulerRequestSpec: - return SchedulerRequestSpec( - request_id=self.request_id, - ref_audio_path=Path(self.ref_audio_path), - prompt_text=self.prompt_text, - prompt_lang=self.prompt_lang, - text=self.text, - text_lang=self.text_lang, - top_k=self.top_k, - top_p=self.top_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - early_stop_num=self.early_stop_num, - ready_step=self.ready_step, - ) - - -@dataclass -class SchedulerDebugExecution: - payload: Dict[str, Any] - - -@dataclass -class SchedulerSubmitExecution: - audio_bytes: bytes - media_type: str - headers: Dict[str, str] - - -@dataclass -class EnginePolicyConfig: - enabled: bool = True - poll_wait_ms: float = 5.0 - decode_backlog_soft_max: int = 0 - finalize_pending_soft_max: int = 0 - prepare_inflight_soft_max: int = 0 - active_decode_soft_max: int = 0 - ready_for_prefill_soft_max: int = 0 - active_request_soft_max: int = 0 - - def to_dict(self) -> Dict[str, Any]: - return { - "enabled": bool(self.enabled), - "poll_wait_ms": float(self.poll_wait_ms), - "decode_backlog_soft_max": int(self.decode_backlog_soft_max), - "finalize_pending_soft_max": int(self.finalize_pending_soft_max), - "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), - "active_decode_soft_max": int(self.active_decode_soft_max), - "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), - "active_request_soft_max": int(self.active_request_soft_max), - } - - -@dataclass -class EngineArbiterConfig: - poll_wait_ms: float = 5.0 - decode_burst: int = 4 - prepare_aging_ms: float = 10.0 - finalize_aging_ms: float = 10.0 - - def to_dict(self) -> Dict[str, Any]: - return { - "poll_wait_ms": float(self.poll_wait_ms), - "decode_burst": int(self.decode_burst), - "prepare_aging_ms": float(self.prepare_aging_ms), - "finalize_aging_ms": float(self.finalize_aging_ms), - } - - -class EngineStatus: - NEW = "NEW" - QUEUED = "QUEUED" - VALIDATED = "VALIDATED" - CPU_PREPARING = "CPU_PREPARING" - GPU_PREPARING = "GPU_PREPARING" - READY_FOR_PREFILL = "READY_FOR_PREFILL" - ACTIVE_DECODE = "ACTIVE_DECODE" - READY_FOR_FINALIZE = "READY_FOR_FINALIZE" - FINALIZING = "FINALIZING" - STREAMING = "STREAMING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - - -@dataclass -class EngineRequestState: - request_id: str - api_mode: str - backend: str - media_type: str - response_streaming: bool - submit_ts: float - deadline_ts: float | None = None - status: str = EngineStatus.NEW - updated_ts: float = 0.0 - error: str | None = None - finish_reason: str | None = None - meta: Dict[str, Any] = field(default_factory=dict) - profile: Dict[str, Any] = field(default_factory=dict) - lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "api_mode": self.api_mode, - "backend": self.backend, - "media_type": self.media_type, - "response_streaming": self.response_streaming, - "status": self.status, - "submit_ts": self.submit_ts, - "updated_ts": self.updated_ts, - "deadline_ts": self.deadline_ts, - "error": self.error, - "finish_reason": self.finish_reason, - "meta": dict(self.meta), - "profile": dict(self.profile), - "lifecycle_timestamps": dict(self.lifecycle_timestamps), - } - - -class EngineRequestRegistry: - def __init__(self, recent_limit: int) -> None: - self.lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_limit = max(1, int(recent_limit)) - - def register( - self, - *, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, - deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, - ) - with self.lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_limit: - self.recent_requests.pop() - - @staticmethod - def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: - if not extra: - return - payload = dict(extra) - backend = payload.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = payload.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = payload.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(payload) - - def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = str(status) - state.updated_ts = now - state.lifecycle_timestamps[str(status)] = now - self._apply_state_extra(state, extra) - - def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: - return - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - self._apply_state_extra(state, extra) - - def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - self._apply_state_extra(state, extra) - self._move_to_recent_locked(state) - - def fail(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) - - def snapshot(self) -> Dict[str, Any]: - with self.lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - recent_limit = self.recent_limit - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": recent_limit, - "active_requests": active, - "recent_requests": recent, - } - - def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.lock: - for state in self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - existing_ids = {item["request_id"] for item in results} - for state in self.recent_requests: - if state.request_id in requested and state.request_id not in existing_ids: - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results - - def has_active(self, request_id: str) -> bool: - with self.lock: - return request_id in self.active_requests - - -@dataclass -class SchedulerPendingJob: - request_id: str - state: T2SRequestState - done_event: threading.Event - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - enqueue_time: float - speed_factor: float - sample_steps: int - media_type: str - admission_wait_ms: float = 0.0 - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - prepare_wall_ms: float = 0.0 - prepare_profile_total_ms: float = 0.0 - first_schedule_time: float | None = None - prefill_ms: float = 0.0 - merge_ms: float = 0.0 - decode_ms: float = 0.0 - finalize_wait_ms: float = 0.0 - synth_ms: float = 0.0 - pack_ms: float = 0.0 - decode_steps: int = 0 - result_ready_time: float | None = None - result: dict | None = None - sample_rate: int | None = None - audio_data: np.ndarray | None = None - error: str | None = None - engine_request_id: str | None = None - - -class SchedulerJobRegistry: - def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: - self._lock = lock - self._job_map: Dict[str, SchedulerPendingJob] = {} - self._total_submitted = 0 - self._total_finished = 0 - - def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: - with self._lock: - if keep_job: - self._job_map[job.request_id] = job - self._total_submitted += 1 - - def get(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.get(request_id) - - def pop(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.pop(request_id, None) - - def remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - - def mark_finished(self) -> None: - with self._lock: - self._total_finished += 1 - - def mark_finished_and_remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - self._total_finished += 1 - - def is_empty(self) -> bool: - with self._lock: - return not self._job_map - - def submitted_count(self) -> int: - with self._lock: - return int(self._total_submitted) - - def finished_count(self) -> int: - with self._lock: - return int(self._total_finished) - - def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: - with self._lock: - request_ids = list(self._job_map.keys()) - return { - "job_count": int(len(request_ids)), - "request_ids": request_ids[: max(0, int(max_request_ids))], - "total_submitted": int(self._total_submitted), - "total_finished": int(self._total_finished), - } - - -class EngineTaskQueueOwner: - def __init__(self, completion_key: str = "total_completed") -> None: - self.condition = threading.Condition() - self.queue: Deque[Any] = deque() - self.total_submitted = 0 - self.total_completed = 0 - self.peak_waiting = 0 - self.completion_key = str(completion_key) - - def enqueue(self, item: Any) -> None: - with self.condition: - self.queue.append(item) - self.total_submitted += 1 - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def enqueue_many(self, items: Sequence[Any]) -> None: - if not items: - return - with self.condition: - for item in items: - self.queue.append(item) - self.total_submitted += len(items) - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def pop_left(self) -> Any | None: - with self.condition: - if not self.queue: - return None - return self.queue.popleft() - - def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: - if count <= 0: - return - with self.condition: - self.total_completed += int(count) - if notify: - self.condition.notify_all() - - def has_items(self) -> bool: - with self.condition: - return bool(self.queue) - - def waiting_count(self) -> int: - with self.condition: - return int(len(self.queue)) - - def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - with self.condition: - waiting_items = list(self.queue)[: max(0, int(max_request_ids))] - snapshot = { - "waiting_count": int(len(self.queue)), - "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], - "peak_waiting": int(self.peak_waiting), - "total_submitted": int(self.total_submitted), - self.completion_key: int(self.total_completed), - } - if extra: - snapshot.update(dict(extra)) - return snapshot - - def peek_oldest_age_ms(self, timestamp_attr: str) -> float: - with self.condition: - if not self.queue: - return 0.0 - enqueue_time = float(getattr(self.queue[0], timestamp_attr)) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def is_drained(self) -> bool: - with self.condition: - return not self.queue and self.total_submitted == self.total_completed - - def take_finalize_batch( - self, - *, - finalize_mode: str, - batch_max_items: int, - batch_wait_s: float, - use_vocoder: bool, - ) -> List[SchedulerFinalizeTask]: - with self.condition: - if not self.queue: - return [] - selected_tasks = [self.queue.popleft()] - if finalize_mode == "sync" or use_vocoder: - return selected_tasks - if batch_max_items <= 1: - return selected_tasks - first_task = selected_tasks[0] - oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) - if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: - self.queue.appendleft(first_task) - return [] - while len(selected_tasks) < batch_max_items: - if not self.queue: - break - matched_index = None - for index, task in enumerate(self.queue): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is None: - break - selected_tasks.append(self.queue[matched_index]) - del self.queue[matched_index] - return selected_tasks - - -class EnginePolicyArbiterController: - def __init__( - self, - *, - policy_config: EnginePolicyConfig, - arbiter_config: EngineArbiterConfig, - snapshot_request_registry: Callable[[], Dict[str, Any]], - get_worker_state: Callable[[], Dict[str, Any]], - snapshot_prepare_state: Callable[[], Dict[str, Any]], - snapshot_finalize_state: Callable[[], Dict[str, Any]], - snapshot_dispatch_state: Callable[[], Dict[str, Any]], - snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], - snapshot_job_registry: Callable[[], Dict[str, Any]], - peek_queue_age_ms: Callable[[str], float], - merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], - ) -> None: - self.policy_config = policy_config - self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) - self.arbiter_config = arbiter_config - self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) - self.condition = threading.Condition() - self.state = EngineArbiterState( - decode_budget_remaining=int(self.arbiter_config.decode_burst), - last_observed_at=time.perf_counter(), - ) - self.snapshot_request_registry = snapshot_request_registry - self.get_worker_state = get_worker_state - self.snapshot_prepare_state = snapshot_prepare_state - self.snapshot_finalize_state = snapshot_finalize_state - self.snapshot_dispatch_state = snapshot_dispatch_state - self.snapshot_decode_runtime_state = snapshot_decode_runtime_state - self.snapshot_job_registry = snapshot_job_registry - self.peek_queue_age_ms = peek_queue_age_ms - self.merge_request_state_profile = merge_request_state_profile - - def snapshot_state(self) -> Dict[str, Any]: - with self.condition: - return { - "config": self.arbiter_config.to_dict(), - "total_ticks": int(self.state.total_ticks), - "total_idle_ticks": int(self.state.total_idle_ticks), - "total_prepare_dispatches": int(self.state.total_prepare_dispatches), - "total_decode_dispatches": int(self.state.total_decode_dispatches), - "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), - "total_finalize_dispatches": int(self.state.total_finalize_dispatches), - "decode_budget_remaining": int(self.state.decode_budget_remaining), - "last_stage": str(self.state.last_stage), - "last_reason": str(self.state.last_reason), - "last_policy_allowed": bool(self.state.last_policy_allowed), - "last_observed_at": float(self.state.last_observed_at), - } - - def notify(self) -> None: - with self.condition: - self.condition.notify_all() - - def wait(self) -> None: - with self.condition: - self.condition.wait(timeout=self.arbiter_poll_s) - - def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - with self.condition: - self.state.total_ticks += 1 - if stage == "idle": - self.state.total_idle_ticks += 1 - elif stage == "prepare": - self.state.total_prepare_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "finalize": - self.state.total_finalize_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "decode_dispatch": - self.state.total_decode_dispatches += 1 - elif stage == "decode_runtime": - self.state.total_decode_runtime_ticks += 1 - self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) - self.state.last_stage = str(stage) - self.state.last_reason = str(reason) - self.state.last_policy_allowed = bool(policy_allowed) - self.state.last_observed_at = time.perf_counter() - - def build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - prepare_dispatcher_state = self.snapshot_prepare_state() - finalize_dispatcher_state = self.snapshot_finalize_state() - dispatcher_state = self.snapshot_dispatch_state() - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - - worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) - worker_decode_active_size = int(worker_state.get("running_requests", 0)) - worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) - worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) - worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) - engine_decode_runtime_state = self.snapshot_decode_runtime_state() - engine_job_registry = self.snapshot_job_registry() - decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) - decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) - return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": worker_pending_jobs, - "worker_decode_active_size": worker_decode_active_size, - "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), - "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), - "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, - "engine_decode_runtime_active_request_count": decode_runtime_active_size, - "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), - "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), - "worker_prepare_inflight": worker_prepare_inflight, - "worker_finalize_pending": worker_finalize_pending, - "worker_finalize_inflight": worker_finalize_inflight, - "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), - "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), - "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), - "decode_backlog": int( - decode_runtime_pending_jobs + decode_runtime_active_size - if bool(worker_state.get("engine_decode_control_enabled", False)) - else worker_pending_jobs + worker_decode_active_size - ), - } - - def build_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self.build_stage_counters(request_registry, worker_state) - config = self.policy_config.to_dict() - blocked_reasons: List[Dict[str, Any]] = [] - finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) - limit_checks = [ - ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), - ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), - ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), - ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), - ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), - ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), - ] - if bool(config["enabled"]): - for name, value, limit in limit_checks: - if limit > 0 and int(value) >= int(limit): - blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) - return { - "enabled": bool(config["enabled"]), - "allowed": (not bool(config["enabled"])) or not blocked_reasons, - "blocked_reasons": blocked_reasons, - "config": config, - "metrics": { - "active_request_count": int(counters["active_request_count"]), - "queued_request_count": int(counters["queued_request_count"]), - "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), - "active_decode_request_count": int(counters["active_decode_request_count"]), - "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), - "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), - "decode_backlog": int(counters["decode_backlog"]), - "prepare_inflight": int(counters["worker_prepare_inflight"]), - "finalize_pending": int(finalize_pending_total), - "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), - "finalize_inflight": int(counters["worker_finalize_inflight"]), - }, - "observed_at": time.perf_counter(), - } - - async def wait_for_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if not self.policy_config.enabled: - return 0.0, snapshot - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - while True: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if snapshot["allowed"]: - wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) - if request_id not in [None, ""]: - self.merge_request_state_profile( - str(request_id), - { - "engine_policy_wait_ms": float(wait_ms), - "engine_policy_snapshot": snapshot, - }, - ) - return wait_ms, snapshot - now = time.perf_counter() - if deadline is not None and now >= deadline: - blocked_summary = ", ".join( - f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) - ) - raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") - await asyncio.sleep(self.policy_poll_s) - - def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) - finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) - decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) - decode_runtime_state = self.snapshot_decode_runtime_state() - worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) - worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) - worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) - worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) - prepare_age_ms = float(self.peek_queue_age_ms("prepare")) - finalize_age_ms = float(self.peek_queue_age_ms("finalize")) - decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) - decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) - policy_allowed = bool(policy_snapshot.get("allowed", True)) - if ( - worker_decode_control_enabled - and worker_decode_has_work - and policy_allowed - and decode_budget_remaining > 0 - and (worker_running_requests > 0 or worker_pending_jobs > 0) - ): - return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state - if ( - worker_decode_control_enabled - and worker_pending_jobs > 0 - and policy_allowed - and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) - ): - return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state - if ( - decode_waiting > 0 - and policy_allowed - and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) - ): - return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state - if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): - return "finalize", "finalize_aging", policy_snapshot, worker_state - if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state - if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: - return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state - if decode_waiting > 0 and policy_allowed: - return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state - if finalize_waiting > 0: - return "finalize", "finalize_fallback", policy_snapshot, worker_state - if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state - return "idle", "no_pending_work", policy_snapshot, worker_state - - -class EngineDecodeRuntimeOwner: - def __init__( - self, - *, - get_decode_runtime_counters: Callable[[], Dict[str, int]], - get_micro_batch_wait_s: Callable[[], float], - ) -> None: - self.get_decode_runtime_counters = get_decode_runtime_counters - self.get_micro_batch_wait_s = get_micro_batch_wait_s - self.condition = threading.Condition() - self.pending_jobs: Deque[SchedulerPendingJob] = deque() - self.active_batch: T2SActiveBatch | None = None - self.state_lock = threading.Lock() - self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) - - @staticmethod - def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - if active_batch is None: - return {} - decode_step_index_max = 0 - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: - decode_step_index_max = int(active_batch.step_indices.max().item()) - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": int(decode_step_index_max), - } - - def snapshot_pending_queue_state(self) -> Dict[str, Any]: - with self.condition: - return { - "pending_jobs": int(len(self.pending_jobs)), - "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], - } - - def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: - with self.condition: - self.pending_jobs.append(job) - self.condition.notify_all() - self.refresh_state("engine_decode_pending_enqueue") - - def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): - return [] - pending_jobs = list(self.pending_jobs) - self.pending_jobs.clear() - self.refresh_state("engine_decode_pending_dequeue") - return pending_jobs - - def pending_age_ms(self) -> float: - with self.condition: - if not self.pending_jobs: - return 0.0 - enqueue_time = float(self.pending_jobs[0].enqueue_time) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def has_pending_jobs(self) -> bool: - with self.condition: - return bool(self.pending_jobs) - - def get_active_batch(self) -> T2SActiveBatch | None: - return self.active_batch - - def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: - self.active_batch = active_batch - - def active_batch_summary(self) -> Dict[str, Any]: - return self.summarize_active_batch(self.active_batch) - - def refresh_state(self, last_event: str) -> None: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) - self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] - self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) - self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) - self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) - self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) - self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) - self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) - self.state.last_event = str(last_event) - self.state.updated_at = float(time.perf_counter()) - - def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - pending_state = self.snapshot_pending_queue_state() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(snapshot.get("active_request_count", 0)) - self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] - self.state.prefill_done = bool(snapshot.get("prefill_done", False)) - self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) - self.state.total_cycles = int(snapshot.get("total_cycles", 0)) - self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) - self.state.step_cycles = int(snapshot.get("step_cycles", 0)) - self.state.has_work = bool( - pending_state.get("pending_jobs", 0) - or snapshot.get("active_request_count", 0) - or snapshot.get("has_work", False) - ) - self.state.last_event = str(snapshot.get("last_event", "unknown")) - self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) - - def snapshot_state(self) -> Dict[str, Any]: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - return { - "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), - "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), - "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), - "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), - "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), - "decode_step_index_max": int( - active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) - ), - "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), - "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), - "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), - "has_work": bool( - pending_state.get("pending_jobs", 0) - or active_batch_summary.get("request_count", self.state.active_request_count) - or self.state.has_work - ), - "last_event": str(self.state.last_event), - "updated_at": float(self.state.updated_at), - } - -@dataclass -class SchedulerFinalizeTask: - request_id: str - item: T2SFinishedItem - enqueued_time: float - - -@dataclass -class EngineDispatchTask: - request_id: str - state: T2SRequestState - speed_factor: float - sample_steps: int - media_type: str - prepare_wall_ms: float - prepare_profile_total_ms: float - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - timeout_sec: float | None - enqueue_time: float - worker_job: SchedulerPendingJob | None = None - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - engine_policy_snapshot: Dict[str, Any] | None = None - error: str | None = None - - -@dataclass -class EngineGpuPrepareTask: - request_id: str - cpu_stage: PreparedCpuStage - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - enqueue_time: float - queue_wait_ms: float = 0.0 - error: str | None = None - - -@dataclass -class EngineFinalizeQueueState: - waiting_count: int - waiting_request_ids: List[str] - peak_waiting: int - total_submitted: int - total_completed: int - - -@dataclass -class EngineArbiterState: - total_ticks: int = 0 - total_idle_ticks: int = 0 - total_prepare_dispatches: int = 0 - total_decode_dispatches: int = 0 - total_decode_runtime_ticks: int = 0 - total_finalize_dispatches: int = 0 - decode_budget_remaining: int = 0 - last_stage: str = "idle" - last_reason: str = "init" - last_observed_at: float = 0.0 - last_policy_allowed: bool = True - - -@dataclass -class EngineDecodeRuntimeState: - pending_jobs: int = 0 - pending_request_ids: List[str] = field(default_factory=list) - active_request_count: int = 0 - active_request_ids: List[str] = field(default_factory=list) - prefill_done: bool = False - decode_step_index_max: int = 0 - total_cycles: int = 0 - prefill_cycles: int = 0 - step_cycles: int = 0 - has_work: bool = False - last_event: str = "init" - updated_at: float = 0.0 - - -@dataclass -class RuntimeStateCallbacks: - update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None - complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None - fail: Callable[[str, str], None] | None = None - decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None - - +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_models import ( + DirectTTSExecution, + NormalizedEngineRequest, + RuntimeControlCallbacks, + SchedulerDebugExecution, + SchedulerSubmitExecution, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_policy import ( + EngineArbiterConfig, + EngineArbiterState, + EnginePolicyArbiterController, + EnginePolicyConfig, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import ( + DefaultReferenceState, + EngineRequestRegistry, + EngineRequestState, + EngineStatus, + ModelRegistry, + ModelRegistryState, + ReferenceRegistry, + SchedulerJobRegistry, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_runtime import ( + EngineDecodeRuntimeOwner, + EngineDecodeRuntimeState, + EngineDispatchTask, + EngineFinalizeQueueState, + EngineGpuPrepareTask, + EngineTaskQueueOwner, + RuntimeStateCallbacks, + SchedulerFinalizeTask, +) + +__all__ = [ + "DefaultReferenceState", + "DirectTTSExecution", + "EngineArbiterConfig", + "EngineArbiterState", + "EngineDecodeRuntimeOwner", + "EngineDecodeRuntimeState", + "EngineDispatchTask", + "EngineFinalizeQueueState", + "EngineGpuPrepareTask", + "EnginePolicyArbiterController", + "EnginePolicyConfig", + "EngineRequestRegistry", + "EngineRequestState", + "EngineStatus", + "EngineTaskQueueOwner", + "ModelRegistry", + "ModelRegistryState", + "NormalizedEngineRequest", + "ReferenceRegistry", + "RuntimeControlCallbacks", + "RuntimeStateCallbacks", + "SchedulerDebugExecution", + "SchedulerFinalizeTask", + "SchedulerJobRegistry", + "SchedulerPendingJob", + "SchedulerSubmitExecution", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py index 04d9090f..934ccf52 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -1,905 +1,25 @@ from __future__ import annotations -import asyncio import os import threading -import time -from collections import deque -from typing import Any, Callable, Deque, Dict, List, Optional - -import numpy as np -import torch +from typing import Callable, List from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState, decode_one_step, merge_active_batches, run_prefill_active_batch, run_scheduler_continuous -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry, SchedulerPendingJob - - -class WorkerPrepareExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - ) -> None: - self.coordinator = PrepareCoordinator(tts) - self.on_state_change = on_state_change - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def snapshot(self) -> Dict[str, int]: - return dict(self.coordinator.snapshot()) - - def get_max_inflight(self) -> int: - return int(self.coordinator.snapshot().get("max_inflight", 0)) - - def is_idle(self) -> bool: - return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - results = await asyncio.gather( - *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] - ) - return [state for state, _, _ in results] - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - try: - return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) - finally: - self._notify_state_change() - - -class WorkerFinalizeExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ) -> None: - self.tts = tts - self.on_state_change = on_state_change - self.external_submit = external_submit - self.condition = threading.Condition() - self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() - self.pending_peak = 0 - self.inflight = 0 - self.inflight_peak = 0 - self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def get_worker_count(self) -> int: - return int(self.worker_count) - - def get_batch_policy(self) -> Dict[str, Any]: - return { - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_s": float(self.batch_wait_s), - } - - def get_pending_count(self) -> int: - with self.condition: - return int(len(self.pending_tasks)) - - def snapshot(self) -> Dict[str, Any]: - with self.condition: - return { - "finalize_pending": int(len(self.pending_tasks)), - "finalize_pending_peak": int(self.pending_peak), - "finalize_inflight": int(self.inflight), - "finalize_inflight_peak": int(self.inflight_peak), - "finalize_workers": int(self.worker_count), - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), - } - - def is_idle(self) -> bool: - with self.condition: - return self.inflight <= 0 and not self.pending_tasks - - def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - if self.external_submit is not None: - self.external_submit(tasks) - self._notify_state_change() - return - with self.condition: - for task in tasks: - self.pending_tasks.append(task) - self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) - self.condition.notify_all() - self._notify_state_change() - - def begin_execution(self, task_count: int) -> None: - if task_count <= 0: - return - with self.condition: - self.inflight += int(task_count) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self.condition.notify_all() - self._notify_state_change() - - def end_execution(self, task_count: int) -> None: - with self.condition: - self.inflight = max(0, self.inflight - int(task_count)) - self.condition.notify_all() - self._notify_state_change() - - def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: - with self.condition: - while not self.pending_tasks: - self.condition.wait() - selected_tasks = [self.pending_tasks.popleft()] - if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - batch_deadline = time.perf_counter() + self.batch_wait_s - while len(selected_tasks) < self.batch_max_items: - if not self.pending_tasks: - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - continue - first_task = selected_tasks[0] - matched_index = None - for index, task in enumerate(self.pending_tasks): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is not None: - selected_tasks.append(self.pending_tasks[matched_index]) - del self.pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - audio_fragment = self.tts.synthesize_audio_request_local( - semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), - phones=job.state.phones.detach().clone().unsqueeze(0), - prompt_semantic=job.state.prompt_semantic.detach().clone(), - prompt_phones=job.state.prompt_phones.detach().clone(), - refer_spec=( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ), - raw_audio=job.state.raw_audio.detach().clone(), - raw_sr=int(job.state.raw_sr), - speed=float(job.speed_factor), - sample_steps=int(job.sample_steps), - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - return self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - - def _synthesize_finished_audio_batch( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> List[tuple[int, np.ndarray]]: - semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] - phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] - refer_specs = [] - speeds = [] - sample_steps_list = [] - for job, _ in jobs_and_items: - refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - ) - speeds.append(float(job.speed_factor)) - sample_steps_list.append(int(job.sample_steps)) - audio_fragments = self.tts.synthesize_audio_requests_local_batched( - semantic_tokens_list=semantic_tokens_list, - phones_list=phones_list, - refer_specs=refer_specs, - speeds=speeds, - sample_steps_list=sample_steps_list, - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - results: List[tuple[int, np.ndarray]] = [] - for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): - results.append( - self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - ) - return results - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - if not jobs_and_items: - return 0.0, [] - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - return float(synth_ms), batch_results - - -class WorkerCompletionBridge: - def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def notify_done_future(self, job: SchedulerPendingJob) -> None: - if job.done_loop is None or job.done_future is None: - return - try: - job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) - except RuntimeError: - pass - - def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.complete is None: - return - self.runtime_callbacks.complete(request_id, extra) - - def runtime_fail(self, request_id: str | None, error: str) -> None: - if request_id is None or self.runtime_callbacks.fail is None: - return - self.runtime_callbacks.fail(request_id, error) - - @staticmethod - def build_completed_job_result( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - finished_at: float | None = None, - ) -> Dict[str, Any]: - finished_at = float(time.perf_counter() if finished_at is None else finished_at) - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "decode_admission_wait_ms": float(job.admission_wait_ms), - "engine_policy_wait_ms": float(job.engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.result = result - return result - - @staticmethod - def build_runtime_complete_payload( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - ) -> Dict[str, Any]: - return { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "sample_rate": int(sample_rate), - "worker_profile": dict(job.result or {}), - } - - def complete_job( - self, - job: SchedulerPendingJob, - *, - runtime_request_id: str | None, - runtime_extra: Optional[Dict[str, Any]] = None, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_complete(runtime_request_id, runtime_extra) - - def fail_job( - self, - job: SchedulerPendingJob, - *, - error: str, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.error = str(error) - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_fail(job.engine_request_id, str(error)) - - def complete_finalize_task( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - job: SchedulerPendingJob, - item: T2SFinishedItem, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - runtime_extra: Optional[Dict[str, Any]] = None - with condition: - if job_registry.get(item.request_id) is not job: - return - self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) - self.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=runtime_extra, - on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), - notify_waiters=condition.notify_all, - ) - - def fail_jobs( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - request_ids: List[str], - error: str, - ) -> None: - if not request_ids: - return - with condition: - for request_id in request_ids: - job = job_registry.get(request_id) - if job is None: - continue - self.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), - ) - condition.notify_all() - - -class WorkerDecodeExecutor: - def __init__(self, tts: TTS, max_steps: int) -> None: - self.tts = tts - self.max_steps = int(max_steps) - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def execute_prefill_merge( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if not pending_jobs: - return { - "executed": False, - "active_batch": active_batch, - "pending_jobs": [], - "prefill_elapsed_s": 0.0, - "merge_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - admitted_finished: List[T2SFinishedItem] = [] - prefill_elapsed_s = 0.0 - merge_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - self._sync_device() - prefill_start = time.perf_counter() - mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - prefill_elapsed_s = time.perf_counter() - prefill_start - if add_prefill_time is not None: - add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(admitted_finished) - merge_start = time.perf_counter() - active_batch = merge_active_batches( - self.tts.t2s_model.model, - active_batch, - admitted_active_batch, - ) - merge_elapsed_s = time.perf_counter() - merge_start - if add_merge_time is not None: - add_merge_time( - [] if active_batch is None else list(active_batch.request_ids), - merge_elapsed_s, - ) - except Exception as exc: - error = str(exc) - error_request_ids = [job.request_id for job in pending_jobs] - if finalize_error is not None: - finalize_error(error_request_ids, error) - return { - "executed": True, - "active_batch": active_batch, - "pending_jobs": list(pending_jobs), - "prefill_elapsed_s": float(prefill_elapsed_s), - "merge_elapsed_s": float(merge_elapsed_s), - "finished_items": list(admitted_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_step( - self, - *, - active_batch: Optional[T2SActiveBatch], - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if active_batch is None: - return { - "executed": False, - "active_batch": None, - "request_ids": [], - "decode_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - active_request_ids: List[str] = [] - step_finished: List[T2SFinishedItem] = [] - decode_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - decode_elapsed_s = time.perf_counter() - decode_start - if add_decode_time is not None: - add_decode_time(active_request_ids, decode_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(step_finished) - except Exception as exc: - error = str(exc) - error_request_ids = list(active_request_ids) - if finalize_error is not None: - finalize_error(error_request_ids, error) - active_batch = None - return { - "executed": True, - "active_batch": active_batch, - "request_ids": active_request_ids, - "decode_elapsed_s": float(decode_elapsed_s), - "finished_items": list(step_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_cycle( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - result = { - "executed": False, - "prefill_merge_executed": False, - "decode_step_executed": False, - "active_batch": active_batch, - "prefill_phase": {}, - "decode_phase": {}, - } - prefill_phase = self.execute_prefill_merge( - pending_jobs=list(pending_jobs), - active_batch=result["active_batch"], - mark_prefill_started=mark_prefill_started, - add_prefill_time=add_prefill_time, - add_merge_time=add_merge_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - prefill_executed = bool(prefill_phase.get("executed", False)) - result["prefill_phase"] = prefill_phase - result["active_batch"] = prefill_phase.get("active_batch") - if prefill_executed: - result["executed"] = True - result["prefill_merge_executed"] = True - decode_phase = self.execute_decode_step( - active_batch=result["active_batch"], - add_decode_time=add_decode_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - decode_executed = bool(decode_phase.get("executed", False)) - result["decode_phase"] = decode_phase - result["active_batch"] = decode_phase.get("active_batch") - if decode_executed: - result["executed"] = True - result["decode_step_executed"] = True - return result - - -class WorkerDecodeLegacyShell: - def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: - self.condition = condition - self.micro_batch_wait_s = float(micro_batch_wait_s) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: - if active_batch is None: - return None - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": ( - int(active_batch.step_indices.max().item()) - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 - else 0 - ), - } - - def current_backlog_locked(self) -> int: - running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) - return int(len(self.pending_jobs) + running_requests) - - def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: - self.pending_jobs.append(job) - - def snapshot_locked(self) -> Dict[str, Any]: - active_batch_summary = self._summarize_active_batch(self.active_batch) - executor_local_pending_jobs = int(len(self.pending_jobs)) - executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) - executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) - return { - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "executor_local_active_batch": active_batch_summary, - } - - def is_idle_locked(self) -> bool: - return self.active_batch is None and not self.pending_jobs - - def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs and self.active_batch is None: - self.condition.wait(timeout=self.micro_batch_wait_s) - elif wait_for_batch and self.pending_jobs: - self.condition.wait(timeout=self.micro_batch_wait_s) - if not self.pending_jobs: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def has_decode_runtime_work(self) -> bool: - with self.condition: - return bool(self.pending_jobs or self.active_batch is not None) - - def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: - active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) - decode_step_index_max = 0 - prefill_done = False - if self.active_batch is not None: - prefill_done = bool(self.active_batch.prefill_done) - if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: - decode_step_index_max = int(self.active_batch.step_indices.max().item()) - return { - "pending_jobs": int(len(self.pending_jobs)), - "active_request_count": int(len(active_request_ids)), - "active_request_ids": active_request_ids[:32], - "prefill_done": bool(prefill_done), - "decode_step_index_max": int(decode_step_index_max), - "total_cycles": int(total_cycles), - "prefill_cycles": int(prefill_cycles), - "step_cycles": int(step_cycles), - "has_work": bool(self.pending_jobs or self.active_batch is not None), - "last_event": str(last_event), - "updated_at": float(time.perf_counter()), - } - - def run_prefill_merge_once_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_prefill_merge(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_step_once_nonblocking( - self, - *, - external_active_batch: Optional[T2SActiveBatch], - execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - active_batch = self.active_batch if external_active_batch is None else external_active_batch - result = execute_decode_step(active_batch) - if external_active_batch is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_cycle_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - on_cycle_executed: Callable[[Dict[str, Any]], None] | None, - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_decode_cycle(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - if result.get("executed") and on_cycle_executed is not None: - on_cycle_executed(result) - return result - - def run_loop( - self, - *, - run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], - ) -> None: - while True: - executed = run_decode_cycle_nonblocking() - if executed.get("executed"): - continue - wait_for_batch = self.active_batch is None - pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) - if pending_jobs: - with self.condition: - self.pending_jobs = pending_jobs + self.pending_jobs - self.condition.notify_all() - continue - time.sleep(self.micro_batch_wait_s) - - -class WorkerDecodeRuntimeTracker: - def __init__( - self, - runtime_callbacks: RuntimeStateCallbacks | None = None, - ) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.total_cycles = 0 - self.prefill_cycles = 0 - self.step_cycles = 0 - - def get_counters(self) -> Dict[str, int]: - return { - "total_cycles": int(self.total_cycles), - "prefill_cycles": int(self.prefill_cycles), - "step_cycles": int(self.step_cycles), - } - - def record_cycle(self, result: Dict[str, Any]) -> None: - if not bool(result.get("executed")): - return - self.total_cycles += 1 - if bool(result.get("prefill_merge_executed")): - self.prefill_cycles += 1 - if bool(result.get("decode_step_executed")): - self.step_cycles += 1 - - def build_runtime_summary_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> Dict[str, Any]: - return legacy_shell.build_runtime_summary_locked( - total_cycles=int(self.total_cycles), - prefill_cycles=int(self.prefill_cycles), - step_cycles=int(self.step_cycles), - last_event=str(last_event), - ) - - def notify_runtime_update_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> None: - if self.runtime_callbacks.decode_runtime_update is None: - return - snapshot = self.build_runtime_summary_locked( - legacy_shell=legacy_shell, - last_event=last_event, - ) - self.runtime_callbacks.decode_runtime_update(snapshot) - - -class UnifiedSchedulerWorker: +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_completion import WorkerCompletionBridge +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_decode import WorkerDecodeExecutor, WorkerDecodeLegacyShell, WorkerDecodeRuntimeTracker +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_execution import WorkerExecutionMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_finalize import WorkerFinalizeExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_prepare import WorkerPrepareExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_runtime import WorkerRuntimeBookkeepingMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_submit import WorkerSubmitLifecycleMixin + + +class UnifiedSchedulerWorker( + WorkerSubmitLifecycleMixin, + WorkerRuntimeBookkeepingMixin, + WorkerExecutionMixin, +): def __init__( self, tts: TTS, @@ -949,562 +69,3 @@ class UnifiedSchedulerWorker: def _notify_worker_state_change(self) -> None: with self.condition: self.condition.notify_all() - - def _current_decode_backlog_locked(self) -> int: - return self.decode_legacy_shell.current_backlog_locked() - - def get_micro_batch_wait_s(self) -> float: - return float(self.micro_batch_wait_s) - - def is_engine_decode_control_enabled(self) -> bool: - return bool(self.engine_decode_control_enabled) - - def get_prepare_max_inflight(self) -> int: - return int(self.prepare_executor.get_max_inflight()) - - def get_capacity_limits(self) -> Dict[str, int]: - return { - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - } - - def get_finalize_batch_policy(self) -> Dict[str, Any]: - return dict(self.finalize_executor.get_batch_policy()) - - def get_decode_runtime_counters(self) -> Dict[str, int]: - with self.condition: - return self.decode_runtime_tracker.get_counters() - - def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: - decode_backlog = self._current_decode_backlog_locked() - finalize_pending = int(self.finalize_executor.get_pending_count()) - prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) - blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max - blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max - return ( - not blocked_decode and not blocked_finalize, - { - "decode_backlog": decode_backlog, - "finalize_pending": finalize_pending, - "prepare_inflight": prepare_inflight, - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - }, - ) - - def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - last_snapshot: Dict[str, int] = {} - while True: - with self.condition: - allowed, snapshot = self._can_accept_submit_locked() - last_snapshot = snapshot - if allowed: - return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot - if deadline is not None and time.perf_counter() >= deadline: - raise TimeoutError( - "scheduler submit admission timeout " - f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" - ) - self.condition.wait(timeout=self.micro_batch_wait_s) - - def _admission_snapshot_locked(self) -> Dict[str, int]: - _, snapshot = self._can_accept_submit_locked() - return snapshot - - async def submit_async( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - return await asyncio.to_thread( - self.submit, - state, - speed_factor, - sample_steps, - media_type, - prepare_wall_ms, - prepare_profile_total_ms, - done_loop, - done_future, - engine_request_id, - timeout_sec, - skip_capacity_wait, - admission_wait_ms_override, - admission_snapshot_override, - engine_policy_wait_ms, - engine_dispatch_wait_ms, - enqueue_pending, - ) - - def snapshot(self) -> dict: - with self.condition: - prepare_state = self.prepare_executor.snapshot() - finalize_state = self.finalize_executor.snapshot() - shell_state = self.decode_legacy_shell.snapshot_locked() - decode_runtime_counters = self.decode_runtime_tracker.get_counters() - engine_owned_decode_state = bool(self.engine_decode_control_enabled) - active_batch_summary = shell_state.get("executor_local_active_batch") - executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) - executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) - executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) - return { - "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, - "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, - "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), - "legacy_state_owner_mode": not engine_owned_decode_state, - "decode_state_owner": "engine" if engine_owned_decode_state else "worker", - "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), - "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), - "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), - "prepare_inflight": prepare_state["inflight"], - "prepare_peak_inflight": prepare_state["peak_inflight"], - "prepare_max_inflight": prepare_state.get("max_inflight", 0), - "prepare_state": dict(prepare_state), - **finalize_state, - "decode_backlog_max": self.decode_backlog_max, - "finalize_pending_max": self.finalize_pending_max, - "active_batch": {} if engine_owned_decode_state else active_batch_summary, - "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, - "total_submitted": self.job_registry.submitted_count(), - "total_finished": self.job_registry.finished_count(), - "drained": self.is_drained(), - } - - def is_drained(self) -> bool: - with self.condition: - return ( - self.decode_legacy_shell.is_idle_locked() - and self.job_registry.is_empty() - and self.prepare_executor.is_idle() - and self.finalize_executor.is_idle() - ) - - def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: - deadline = time.perf_counter() + max(0.0, timeout_sec) - while time.perf_counter() < deadline: - if self.is_drained(): - return True - time.sleep(poll_interval_sec) - return self.is_drained() - - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - if skip_capacity_wait: - with self.condition: - admission_snapshot = ( - dict(admission_snapshot_override) - if admission_snapshot_override is not None - else dict(self._admission_snapshot_locked()) - ) - admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) - else: - admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - admission_wait_ms=float(admission_wait_ms), - engine_policy_wait_ms=float(engine_policy_wait_ms), - engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - engine_request_id=engine_request_id or state.request_id, - ) - with self.condition: - self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) - if enqueue_pending: - self.decode_legacy_shell.enqueue_pending_job_locked(job) - self.condition.notify_all() - if enqueue_pending: - self._notify_decode_runtime_state("submit") - self._runtime_update( - job.engine_request_id, - EngineStatus.QUEUED, - { - "scheduler_request_id": job.request_id, - "decode_admission_wait_ms": float(admission_wait_ms), - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), - "admission_snapshot": dict(admission_snapshot), - }, - ) - return job - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return await self.prepare_executor.prepare_states_batch_async(specs) - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) - - def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in pending_jobs: - job.first_schedule_time = float(started_at) - self._runtime_update( - job.engine_request_id, - EngineStatus.GPU_PREPARING, - {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, - ) - - def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.prefill_ms += delta_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - activate_request_ids: List[str] = [] - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.finalize_wait_ms += float(delta_ms) - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks: List[SchedulerFinalizeTask] = [] - with self.condition: - for item in items: - job = self.job_registry.get(item.request_id) - if job is not None: - self._runtime_update( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - }, - ) - tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) - self.finalize_executor.enqueue_tasks(tasks) - - def begin_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.begin_execution(task_count) - - def end_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.end_execution(task_count) - - def record_external_job_done(self, request_id: str) -> None: - with self.condition: - self.job_registry.mark_finished_and_remove(request_id) - self.condition.notify_all() - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) - - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - self.completion_bridge.complete_finalize_task( - condition=self.condition, - job_registry=self.job_registry, - job=job, - item=item, - sample_rate=sample_rate, - audio_data=audio_data, - ) - - def _finalize_error(self, request_ids: List[str], error: str) -> None: - self.completion_bridge.fail_jobs( - condition=self.condition, - job_registry=self.job_registry, - request_ids=request_ids, - error=error, - ) - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def _notify_done_future(self, job: SchedulerPendingJob) -> None: - self.completion_bridge.notify_done_future(job) - - def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.update is None: - return - self.runtime_callbacks.update(request_id, status, extra) - - def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - self.completion_bridge.runtime_complete(request_id, extra) - - def _runtime_fail(self, request_id: str | None, error: str) -> None: - self.completion_bridge.runtime_fail(request_id, error) - - def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: - return self.decode_runtime_tracker.build_runtime_summary_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _notify_decode_runtime_state(self, last_event: str) -> None: - with self.condition: - self.decode_runtime_tracker.notify_runtime_update_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: - with self.condition: - self.decode_runtime_tracker.record_cycle(result) - - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) - - def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) - - def has_decode_runtime_work(self) -> bool: - return self.decode_legacy_shell.has_decode_runtime_work() - - def execute_prefill_merge( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_prefill_merge( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_step( - self, - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_decode_step( - active_batch=active_batch, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_cycle( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_executor.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - self._record_decode_runtime_cycle(result) - return result - - def run_prefill_merge_once_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("prefill_merge") - return result - - def run_decode_step_once_nonblocking( - self, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_step_once_nonblocking( - external_active_batch=external_active_batch, - execute_decode_step=lambda batch_state: self.execute_decode_step( - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("decode_step") - return result - - def run_decode_cycle_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_cycle_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - on_cycle_executed=None, - ) - if result.get("executed") and emit_runtime_state: - self._notify_decode_runtime_state("decode_cycle") - return result - - def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - with self.condition: - for task in tasks: - job = self.job_registry.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: - self._runtime_update( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_registry.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self.finalize_executor.end_execution(len(tasks)) - - def _run_finalize_loop(self) -> None: - while True: - tasks = self.finalize_executor.take_task_batch_blocking() - self.execute_finalize_tasks(tasks) - - def _run_loop(self) -> None: - self.decode_legacy_shell.run_loop( - run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() - ) - - diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py new file mode 100644 index 00000000..da2c057a --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob + + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py new file mode 100644 index 00000000..784f71d0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + T2SActiveBatch, + T2SFinishedItem, + decode_one_step, + merge_active_batches, + run_prefill_active_batch, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: + while True: + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: + with self.condition: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py new file mode 100644 index 00000000..465f7a2c --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerExecutionMixin: + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) + + def _run_loop(self) -> None: + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py new file mode 100644 index 00000000..4f5833fd --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import os +import threading +import time +from collections import deque +from typing import Any, Callable, Deque, Dict, List + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: + with self.condition: + return { + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), + } + + def is_idle(self) -> bool: + with self.condition: + return self.inflight <= 0 and not self.pending_tasks + + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py new file mode 100644 index 00000000..28da24ee --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Callable, Dict, List + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState + + +class WorkerPrepareExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) + + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) + + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py new file mode 100644 index 00000000..de12f5e1 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerRuntimeBookkeepingMixin: + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py new file mode 100644 index 00000000..1e67f8d3 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob + + +class WorkerSubmitLifecycleMixin: + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)