diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py new file mode 100644 index 00000000..2c0cc9ac --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Generator, List, Optional + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + ready_step=self.ready_step, + ) + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py new file mode 100644 index 00000000..b6c5ca4d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus + + +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage == "prepare": + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + return "prepare", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + return "prepare", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py new file mode 100644 index 00000000..111ca500 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Deque, Dict, Optional, Sequence + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]: + requested = set(request_ids) + results: list[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + engine_request_id: str | None = None + + +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py new file mode 100644 index 00000000..db03a0c3 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, List, Optional, Sequence + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + queue_wait_ms: float = 0.0 + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py index 3a124f4e..ac1adac5 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py @@ -1,1150 +1,63 @@ -from __future__ import annotations - -import asyncio -import os -import threading -import time -import uuid -from collections import deque -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Deque, Dict, List, Optional, Sequence, Tuple, Union - -import numpy as np -import torch - -from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState - - -@dataclass -class RuntimeControlCallbacks: - restart: Callable[[], None] | None = None - exit: Callable[[], None] | None = None - - -@dataclass -class DefaultReferenceState: - ref_audio_path: str | None = None - updated_at: float = 0.0 - - -class ReferenceRegistry: - def __init__(self) -> None: - self._lock = threading.Lock() - self._state = DefaultReferenceState() - - def set_default(self, ref_audio_path: str) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) - return self._state - - def clear(self) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState() - return self._state - - def get_default(self) -> DefaultReferenceState: - with self._lock: - return DefaultReferenceState( - ref_audio_path=self._state.ref_audio_path, - updated_at=self._state.updated_at, - ) - - -@dataclass -class ModelRegistryState: - t2s_weights_path: str - vits_weights_path: str - generation: int = 0 - t2s_generation: int = 0 - vits_generation: int = 0 - updated_at: float = field(default_factory=time.time) - - -class ModelRegistry: - def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: - self._lock = threading.Lock() - self._state = ModelRegistryState( - t2s_weights_path=str(t2s_weights_path), - vits_weights_path=str(vits_weights_path), - ) - - def snapshot(self) -> ModelRegistryState: - with self._lock: - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.t2s_weights_path = str(weights_path) - self._state.generation += 1 - self._state.t2s_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.vits_weights_path = str(weights_path) - self._state.generation += 1 - self._state.vits_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - -@dataclass -class DirectTTSExecution: - media_type: str - streaming: bool - audio_generator: Optional[Generator[bytes, None, None]] = None - audio_bytes: Optional[bytes] = None - request_id: Optional[str] = None - - -@dataclass -class NormalizedEngineRequest: - request_id: str - text: str - text_lang: str - ref_audio_path: str - prompt_lang: str - prompt_text: str = "" - aux_ref_audio_paths: List[str] | None = None - top_k: int = 15 - top_p: float = 1.0 - temperature: float = 1.0 - repetition_penalty: float = 1.35 - early_stop_num: int = -1 - ready_step: int = 0 - text_split_method: str = "cut5" - batch_size: int = 1 - batch_threshold: float = 0.75 - split_bucket: bool = False - speed_factor: float = 1.0 - fragment_interval: float = 0.3 - seed: int = -1 - media_type: str = "wav" - streaming_mode: bool | int = False - return_fragment: bool = False - fixed_length_chunk: bool = False - response_streaming: bool = False - parallel_infer: bool = False - sample_steps: int = 32 - super_sampling: bool = False - overlap_length: int = 2 - min_chunk_length: int = 16 - timeout_sec: float | None = None - - def to_payload(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "text": self.text, - "text_lang": self.text_lang, - "ref_audio_path": self.ref_audio_path, - "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, - "prompt_text": self.prompt_text, - "prompt_lang": self.prompt_lang, - "top_k": self.top_k, - "top_p": self.top_p, - "temperature": self.temperature, - "text_split_method": self.text_split_method, - "batch_size": self.batch_size, - "batch_threshold": self.batch_threshold, - "speed_factor": self.speed_factor, - "split_bucket": self.split_bucket, - "fragment_interval": self.fragment_interval, - "seed": self.seed, - "media_type": self.media_type, - "streaming_mode": self.streaming_mode, - "return_fragment": self.return_fragment, - "fixed_length_chunk": self.fixed_length_chunk, - "response_streaming": self.response_streaming, - "parallel_infer": self.parallel_infer, - "repetition_penalty": self.repetition_penalty, - "sample_steps": self.sample_steps, - "super_sampling": self.super_sampling, - "overlap_length": self.overlap_length, - "min_chunk_length": self.min_chunk_length, - "early_stop_num": self.early_stop_num, - "ready_step": self.ready_step, - "timeout_sec": self.timeout_sec, - } - - def to_scheduler_spec(self) -> SchedulerRequestSpec: - return SchedulerRequestSpec( - request_id=self.request_id, - ref_audio_path=Path(self.ref_audio_path), - prompt_text=self.prompt_text, - prompt_lang=self.prompt_lang, - text=self.text, - text_lang=self.text_lang, - top_k=self.top_k, - top_p=self.top_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - early_stop_num=self.early_stop_num, - ready_step=self.ready_step, - ) - - -@dataclass -class SchedulerDebugExecution: - payload: Dict[str, Any] - - -@dataclass -class SchedulerSubmitExecution: - audio_bytes: bytes - media_type: str - headers: Dict[str, str] - - -@dataclass -class EnginePolicyConfig: - enabled: bool = True - poll_wait_ms: float = 5.0 - decode_backlog_soft_max: int = 0 - finalize_pending_soft_max: int = 0 - prepare_inflight_soft_max: int = 0 - active_decode_soft_max: int = 0 - ready_for_prefill_soft_max: int = 0 - active_request_soft_max: int = 0 - - def to_dict(self) -> Dict[str, Any]: - return { - "enabled": bool(self.enabled), - "poll_wait_ms": float(self.poll_wait_ms), - "decode_backlog_soft_max": int(self.decode_backlog_soft_max), - "finalize_pending_soft_max": int(self.finalize_pending_soft_max), - "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), - "active_decode_soft_max": int(self.active_decode_soft_max), - "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), - "active_request_soft_max": int(self.active_request_soft_max), - } - - -@dataclass -class EngineArbiterConfig: - poll_wait_ms: float = 5.0 - decode_burst: int = 4 - prepare_aging_ms: float = 10.0 - finalize_aging_ms: float = 10.0 - - def to_dict(self) -> Dict[str, Any]: - return { - "poll_wait_ms": float(self.poll_wait_ms), - "decode_burst": int(self.decode_burst), - "prepare_aging_ms": float(self.prepare_aging_ms), - "finalize_aging_ms": float(self.finalize_aging_ms), - } - - -class EngineStatus: - NEW = "NEW" - QUEUED = "QUEUED" - VALIDATED = "VALIDATED" - CPU_PREPARING = "CPU_PREPARING" - GPU_PREPARING = "GPU_PREPARING" - READY_FOR_PREFILL = "READY_FOR_PREFILL" - ACTIVE_DECODE = "ACTIVE_DECODE" - READY_FOR_FINALIZE = "READY_FOR_FINALIZE" - FINALIZING = "FINALIZING" - STREAMING = "STREAMING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - - -@dataclass -class EngineRequestState: - request_id: str - api_mode: str - backend: str - media_type: str - response_streaming: bool - submit_ts: float - deadline_ts: float | None = None - status: str = EngineStatus.NEW - updated_ts: float = 0.0 - error: str | None = None - finish_reason: str | None = None - meta: Dict[str, Any] = field(default_factory=dict) - profile: Dict[str, Any] = field(default_factory=dict) - lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "api_mode": self.api_mode, - "backend": self.backend, - "media_type": self.media_type, - "response_streaming": self.response_streaming, - "status": self.status, - "submit_ts": self.submit_ts, - "updated_ts": self.updated_ts, - "deadline_ts": self.deadline_ts, - "error": self.error, - "finish_reason": self.finish_reason, - "meta": dict(self.meta), - "profile": dict(self.profile), - "lifecycle_timestamps": dict(self.lifecycle_timestamps), - } - - -class EngineRequestRegistry: - def __init__(self, recent_limit: int) -> None: - self.lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_limit = max(1, int(recent_limit)) - - def register( - self, - *, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, - deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, - ) - with self.lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_limit: - self.recent_requests.pop() - - @staticmethod - def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: - if not extra: - return - payload = dict(extra) - backend = payload.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = payload.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = payload.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(payload) - - def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = str(status) - state.updated_ts = now - state.lifecycle_timestamps[str(status)] = now - self._apply_state_extra(state, extra) - - def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: - return - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - self._apply_state_extra(state, extra) - - def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - self._apply_state_extra(state, extra) - self._move_to_recent_locked(state) - - def fail(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) - - def snapshot(self) -> Dict[str, Any]: - with self.lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - recent_limit = self.recent_limit - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": recent_limit, - "active_requests": active, - "recent_requests": recent, - } - - def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.lock: - for state in self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - existing_ids = {item["request_id"] for item in results} - for state in self.recent_requests: - if state.request_id in requested and state.request_id not in existing_ids: - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results - - def has_active(self, request_id: str) -> bool: - with self.lock: - return request_id in self.active_requests - - -@dataclass -class SchedulerPendingJob: - request_id: str - state: T2SRequestState - done_event: threading.Event - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - enqueue_time: float - speed_factor: float - sample_steps: int - media_type: str - admission_wait_ms: float = 0.0 - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - prepare_wall_ms: float = 0.0 - prepare_profile_total_ms: float = 0.0 - first_schedule_time: float | None = None - prefill_ms: float = 0.0 - merge_ms: float = 0.0 - decode_ms: float = 0.0 - finalize_wait_ms: float = 0.0 - synth_ms: float = 0.0 - pack_ms: float = 0.0 - decode_steps: int = 0 - result_ready_time: float | None = None - result: dict | None = None - sample_rate: int | None = None - audio_data: np.ndarray | None = None - error: str | None = None - engine_request_id: str | None = None - - -class SchedulerJobRegistry: - def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: - self._lock = lock - self._job_map: Dict[str, SchedulerPendingJob] = {} - self._total_submitted = 0 - self._total_finished = 0 - - def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: - with self._lock: - if keep_job: - self._job_map[job.request_id] = job - self._total_submitted += 1 - - def get(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.get(request_id) - - def pop(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.pop(request_id, None) - - def remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - - def mark_finished(self) -> None: - with self._lock: - self._total_finished += 1 - - def mark_finished_and_remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - self._total_finished += 1 - - def is_empty(self) -> bool: - with self._lock: - return not self._job_map - - def submitted_count(self) -> int: - with self._lock: - return int(self._total_submitted) - - def finished_count(self) -> int: - with self._lock: - return int(self._total_finished) - - def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: - with self._lock: - request_ids = list(self._job_map.keys()) - return { - "job_count": int(len(request_ids)), - "request_ids": request_ids[: max(0, int(max_request_ids))], - "total_submitted": int(self._total_submitted), - "total_finished": int(self._total_finished), - } - - -class EngineTaskQueueOwner: - def __init__(self, completion_key: str = "total_completed") -> None: - self.condition = threading.Condition() - self.queue: Deque[Any] = deque() - self.total_submitted = 0 - self.total_completed = 0 - self.peak_waiting = 0 - self.completion_key = str(completion_key) - - def enqueue(self, item: Any) -> None: - with self.condition: - self.queue.append(item) - self.total_submitted += 1 - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def enqueue_many(self, items: Sequence[Any]) -> None: - if not items: - return - with self.condition: - for item in items: - self.queue.append(item) - self.total_submitted += len(items) - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def pop_left(self) -> Any | None: - with self.condition: - if not self.queue: - return None - return self.queue.popleft() - - def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: - if count <= 0: - return - with self.condition: - self.total_completed += int(count) - if notify: - self.condition.notify_all() - - def has_items(self) -> bool: - with self.condition: - return bool(self.queue) - - def waiting_count(self) -> int: - with self.condition: - return int(len(self.queue)) - - def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - with self.condition: - waiting_items = list(self.queue)[: max(0, int(max_request_ids))] - snapshot = { - "waiting_count": int(len(self.queue)), - "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], - "peak_waiting": int(self.peak_waiting), - "total_submitted": int(self.total_submitted), - self.completion_key: int(self.total_completed), - } - if extra: - snapshot.update(dict(extra)) - return snapshot - - def peek_oldest_age_ms(self, timestamp_attr: str) -> float: - with self.condition: - if not self.queue: - return 0.0 - enqueue_time = float(getattr(self.queue[0], timestamp_attr)) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def is_drained(self) -> bool: - with self.condition: - return not self.queue and self.total_submitted == self.total_completed - - def take_finalize_batch( - self, - *, - finalize_mode: str, - batch_max_items: int, - batch_wait_s: float, - use_vocoder: bool, - ) -> List[SchedulerFinalizeTask]: - with self.condition: - if not self.queue: - return [] - selected_tasks = [self.queue.popleft()] - if finalize_mode == "sync" or use_vocoder: - return selected_tasks - if batch_max_items <= 1: - return selected_tasks - first_task = selected_tasks[0] - oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) - if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: - self.queue.appendleft(first_task) - return [] - while len(selected_tasks) < batch_max_items: - if not self.queue: - break - matched_index = None - for index, task in enumerate(self.queue): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is None: - break - selected_tasks.append(self.queue[matched_index]) - del self.queue[matched_index] - return selected_tasks - - -class EnginePolicyArbiterController: - def __init__( - self, - *, - policy_config: EnginePolicyConfig, - arbiter_config: EngineArbiterConfig, - snapshot_request_registry: Callable[[], Dict[str, Any]], - get_worker_state: Callable[[], Dict[str, Any]], - snapshot_prepare_state: Callable[[], Dict[str, Any]], - snapshot_finalize_state: Callable[[], Dict[str, Any]], - snapshot_dispatch_state: Callable[[], Dict[str, Any]], - snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], - snapshot_job_registry: Callable[[], Dict[str, Any]], - peek_queue_age_ms: Callable[[str], float], - merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], - ) -> None: - self.policy_config = policy_config - self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) - self.arbiter_config = arbiter_config - self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) - self.condition = threading.Condition() - self.state = EngineArbiterState( - decode_budget_remaining=int(self.arbiter_config.decode_burst), - last_observed_at=time.perf_counter(), - ) - self.snapshot_request_registry = snapshot_request_registry - self.get_worker_state = get_worker_state - self.snapshot_prepare_state = snapshot_prepare_state - self.snapshot_finalize_state = snapshot_finalize_state - self.snapshot_dispatch_state = snapshot_dispatch_state - self.snapshot_decode_runtime_state = snapshot_decode_runtime_state - self.snapshot_job_registry = snapshot_job_registry - self.peek_queue_age_ms = peek_queue_age_ms - self.merge_request_state_profile = merge_request_state_profile - - def snapshot_state(self) -> Dict[str, Any]: - with self.condition: - return { - "config": self.arbiter_config.to_dict(), - "total_ticks": int(self.state.total_ticks), - "total_idle_ticks": int(self.state.total_idle_ticks), - "total_prepare_dispatches": int(self.state.total_prepare_dispatches), - "total_decode_dispatches": int(self.state.total_decode_dispatches), - "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), - "total_finalize_dispatches": int(self.state.total_finalize_dispatches), - "decode_budget_remaining": int(self.state.decode_budget_remaining), - "last_stage": str(self.state.last_stage), - "last_reason": str(self.state.last_reason), - "last_policy_allowed": bool(self.state.last_policy_allowed), - "last_observed_at": float(self.state.last_observed_at), - } - - def notify(self) -> None: - with self.condition: - self.condition.notify_all() - - def wait(self) -> None: - with self.condition: - self.condition.wait(timeout=self.arbiter_poll_s) - - def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - with self.condition: - self.state.total_ticks += 1 - if stage == "idle": - self.state.total_idle_ticks += 1 - elif stage == "prepare": - self.state.total_prepare_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "finalize": - self.state.total_finalize_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "decode_dispatch": - self.state.total_decode_dispatches += 1 - elif stage == "decode_runtime": - self.state.total_decode_runtime_ticks += 1 - self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) - self.state.last_stage = str(stage) - self.state.last_reason = str(reason) - self.state.last_policy_allowed = bool(policy_allowed) - self.state.last_observed_at = time.perf_counter() - - def build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - prepare_dispatcher_state = self.snapshot_prepare_state() - finalize_dispatcher_state = self.snapshot_finalize_state() - dispatcher_state = self.snapshot_dispatch_state() - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - - worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) - worker_decode_active_size = int(worker_state.get("running_requests", 0)) - worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) - worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) - worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) - engine_decode_runtime_state = self.snapshot_decode_runtime_state() - engine_job_registry = self.snapshot_job_registry() - decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) - decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) - return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": worker_pending_jobs, - "worker_decode_active_size": worker_decode_active_size, - "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), - "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), - "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, - "engine_decode_runtime_active_request_count": decode_runtime_active_size, - "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), - "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), - "worker_prepare_inflight": worker_prepare_inflight, - "worker_finalize_pending": worker_finalize_pending, - "worker_finalize_inflight": worker_finalize_inflight, - "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), - "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), - "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), - "decode_backlog": int( - decode_runtime_pending_jobs + decode_runtime_active_size - if bool(worker_state.get("engine_decode_control_enabled", False)) - else worker_pending_jobs + worker_decode_active_size - ), - } - - def build_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self.build_stage_counters(request_registry, worker_state) - config = self.policy_config.to_dict() - blocked_reasons: List[Dict[str, Any]] = [] - finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) - limit_checks = [ - ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), - ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), - ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), - ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), - ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), - ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), - ] - if bool(config["enabled"]): - for name, value, limit in limit_checks: - if limit > 0 and int(value) >= int(limit): - blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) - return { - "enabled": bool(config["enabled"]), - "allowed": (not bool(config["enabled"])) or not blocked_reasons, - "blocked_reasons": blocked_reasons, - "config": config, - "metrics": { - "active_request_count": int(counters["active_request_count"]), - "queued_request_count": int(counters["queued_request_count"]), - "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), - "active_decode_request_count": int(counters["active_decode_request_count"]), - "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), - "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), - "decode_backlog": int(counters["decode_backlog"]), - "prepare_inflight": int(counters["worker_prepare_inflight"]), - "finalize_pending": int(finalize_pending_total), - "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), - "finalize_inflight": int(counters["worker_finalize_inflight"]), - }, - "observed_at": time.perf_counter(), - } - - async def wait_for_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if not self.policy_config.enabled: - return 0.0, snapshot - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - while True: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if snapshot["allowed"]: - wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) - if request_id not in [None, ""]: - self.merge_request_state_profile( - str(request_id), - { - "engine_policy_wait_ms": float(wait_ms), - "engine_policy_snapshot": snapshot, - }, - ) - return wait_ms, snapshot - now = time.perf_counter() - if deadline is not None and now >= deadline: - blocked_summary = ", ".join( - f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) - ) - raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") - await asyncio.sleep(self.policy_poll_s) - - def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) - finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) - decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) - decode_runtime_state = self.snapshot_decode_runtime_state() - worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) - worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) - worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) - worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) - prepare_age_ms = float(self.peek_queue_age_ms("prepare")) - finalize_age_ms = float(self.peek_queue_age_ms("finalize")) - decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) - decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) - policy_allowed = bool(policy_snapshot.get("allowed", True)) - if ( - worker_decode_control_enabled - and worker_decode_has_work - and policy_allowed - and decode_budget_remaining > 0 - and (worker_running_requests > 0 or worker_pending_jobs > 0) - ): - return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state - if ( - worker_decode_control_enabled - and worker_pending_jobs > 0 - and policy_allowed - and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) - ): - return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state - if ( - decode_waiting > 0 - and policy_allowed - and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) - ): - return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state - if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): - return "finalize", "finalize_aging", policy_snapshot, worker_state - if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state - if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: - return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state - if decode_waiting > 0 and policy_allowed: - return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state - if finalize_waiting > 0: - return "finalize", "finalize_fallback", policy_snapshot, worker_state - if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state - return "idle", "no_pending_work", policy_snapshot, worker_state - - -class EngineDecodeRuntimeOwner: - def __init__( - self, - *, - get_decode_runtime_counters: Callable[[], Dict[str, int]], - get_micro_batch_wait_s: Callable[[], float], - ) -> None: - self.get_decode_runtime_counters = get_decode_runtime_counters - self.get_micro_batch_wait_s = get_micro_batch_wait_s - self.condition = threading.Condition() - self.pending_jobs: Deque[SchedulerPendingJob] = deque() - self.active_batch: T2SActiveBatch | None = None - self.state_lock = threading.Lock() - self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) - - @staticmethod - def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - if active_batch is None: - return {} - decode_step_index_max = 0 - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: - decode_step_index_max = int(active_batch.step_indices.max().item()) - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": int(decode_step_index_max), - } - - def snapshot_pending_queue_state(self) -> Dict[str, Any]: - with self.condition: - return { - "pending_jobs": int(len(self.pending_jobs)), - "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], - } - - def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: - with self.condition: - self.pending_jobs.append(job) - self.condition.notify_all() - self.refresh_state("engine_decode_pending_enqueue") - - def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): - return [] - pending_jobs = list(self.pending_jobs) - self.pending_jobs.clear() - self.refresh_state("engine_decode_pending_dequeue") - return pending_jobs - - def pending_age_ms(self) -> float: - with self.condition: - if not self.pending_jobs: - return 0.0 - enqueue_time = float(self.pending_jobs[0].enqueue_time) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def has_pending_jobs(self) -> bool: - with self.condition: - return bool(self.pending_jobs) - - def get_active_batch(self) -> T2SActiveBatch | None: - return self.active_batch - - def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: - self.active_batch = active_batch - - def active_batch_summary(self) -> Dict[str, Any]: - return self.summarize_active_batch(self.active_batch) - - def refresh_state(self, last_event: str) -> None: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) - self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] - self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) - self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) - self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) - self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) - self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) - self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) - self.state.last_event = str(last_event) - self.state.updated_at = float(time.perf_counter()) - - def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - pending_state = self.snapshot_pending_queue_state() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(snapshot.get("active_request_count", 0)) - self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] - self.state.prefill_done = bool(snapshot.get("prefill_done", False)) - self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) - self.state.total_cycles = int(snapshot.get("total_cycles", 0)) - self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) - self.state.step_cycles = int(snapshot.get("step_cycles", 0)) - self.state.has_work = bool( - pending_state.get("pending_jobs", 0) - or snapshot.get("active_request_count", 0) - or snapshot.get("has_work", False) - ) - self.state.last_event = str(snapshot.get("last_event", "unknown")) - self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) - - def snapshot_state(self) -> Dict[str, Any]: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - return { - "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), - "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), - "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), - "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), - "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), - "decode_step_index_max": int( - active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) - ), - "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), - "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), - "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), - "has_work": bool( - pending_state.get("pending_jobs", 0) - or active_batch_summary.get("request_count", self.state.active_request_count) - or self.state.has_work - ), - "last_event": str(self.state.last_event), - "updated_at": float(self.state.updated_at), - } - -@dataclass -class SchedulerFinalizeTask: - request_id: str - item: T2SFinishedItem - enqueued_time: float - - -@dataclass -class EngineDispatchTask: - request_id: str - state: T2SRequestState - speed_factor: float - sample_steps: int - media_type: str - prepare_wall_ms: float - prepare_profile_total_ms: float - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - timeout_sec: float | None - enqueue_time: float - worker_job: SchedulerPendingJob | None = None - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - engine_policy_snapshot: Dict[str, Any] | None = None - error: str | None = None - - -@dataclass -class EngineGpuPrepareTask: - request_id: str - cpu_stage: PreparedCpuStage - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - enqueue_time: float - queue_wait_ms: float = 0.0 - error: str | None = None - - -@dataclass -class EngineFinalizeQueueState: - waiting_count: int - waiting_request_ids: List[str] - peak_waiting: int - total_submitted: int - total_completed: int - - -@dataclass -class EngineArbiterState: - total_ticks: int = 0 - total_idle_ticks: int = 0 - total_prepare_dispatches: int = 0 - total_decode_dispatches: int = 0 - total_decode_runtime_ticks: int = 0 - total_finalize_dispatches: int = 0 - decode_budget_remaining: int = 0 - last_stage: str = "idle" - last_reason: str = "init" - last_observed_at: float = 0.0 - last_policy_allowed: bool = True - - -@dataclass -class EngineDecodeRuntimeState: - pending_jobs: int = 0 - pending_request_ids: List[str] = field(default_factory=list) - active_request_count: int = 0 - active_request_ids: List[str] = field(default_factory=list) - prefill_done: bool = False - decode_step_index_max: int = 0 - total_cycles: int = 0 - prefill_cycles: int = 0 - step_cycles: int = 0 - has_work: bool = False - last_event: str = "init" - updated_at: float = 0.0 - - -@dataclass -class RuntimeStateCallbacks: - update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None - complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None - fail: Callable[[str, str], None] | None = None - decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None - - +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_models import ( + DirectTTSExecution, + NormalizedEngineRequest, + RuntimeControlCallbacks, + SchedulerDebugExecution, + SchedulerSubmitExecution, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_policy import ( + EngineArbiterConfig, + EngineArbiterState, + EnginePolicyArbiterController, + EnginePolicyConfig, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import ( + DefaultReferenceState, + EngineRequestRegistry, + EngineRequestState, + EngineStatus, + ModelRegistry, + ModelRegistryState, + ReferenceRegistry, + SchedulerJobRegistry, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_runtime import ( + EngineDecodeRuntimeOwner, + EngineDecodeRuntimeState, + EngineDispatchTask, + EngineFinalizeQueueState, + EngineGpuPrepareTask, + EngineTaskQueueOwner, + RuntimeStateCallbacks, + SchedulerFinalizeTask, +) + +__all__ = [ + "DefaultReferenceState", + "DirectTTSExecution", + "EngineArbiterConfig", + "EngineArbiterState", + "EngineDecodeRuntimeOwner", + "EngineDecodeRuntimeState", + "EngineDispatchTask", + "EngineFinalizeQueueState", + "EngineGpuPrepareTask", + "EnginePolicyArbiterController", + "EnginePolicyConfig", + "EngineRequestRegistry", + "EngineRequestState", + "EngineStatus", + "EngineTaskQueueOwner", + "ModelRegistry", + "ModelRegistryState", + "NormalizedEngineRequest", + "ReferenceRegistry", + "RuntimeControlCallbacks", + "RuntimeStateCallbacks", + "SchedulerDebugExecution", + "SchedulerFinalizeTask", + "SchedulerJobRegistry", + "SchedulerPendingJob", + "SchedulerSubmitExecution", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py index 04d9090f..934ccf52 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -1,905 +1,25 @@ from __future__ import annotations -import asyncio import os import threading -import time -from collections import deque -from typing import Any, Callable, Deque, Dict, List, Optional - -import numpy as np -import torch +from typing import Callable, List from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState, decode_one_step, merge_active_batches, run_prefill_active_batch, run_scheduler_continuous -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry, SchedulerPendingJob - - -class WorkerPrepareExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - ) -> None: - self.coordinator = PrepareCoordinator(tts) - self.on_state_change = on_state_change - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def snapshot(self) -> Dict[str, int]: - return dict(self.coordinator.snapshot()) - - def get_max_inflight(self) -> int: - return int(self.coordinator.snapshot().get("max_inflight", 0)) - - def is_idle(self) -> bool: - return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - results = await asyncio.gather( - *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] - ) - return [state for state, _, _ in results] - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - try: - return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) - finally: - self._notify_state_change() - - -class WorkerFinalizeExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ) -> None: - self.tts = tts - self.on_state_change = on_state_change - self.external_submit = external_submit - self.condition = threading.Condition() - self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() - self.pending_peak = 0 - self.inflight = 0 - self.inflight_peak = 0 - self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def get_worker_count(self) -> int: - return int(self.worker_count) - - def get_batch_policy(self) -> Dict[str, Any]: - return { - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_s": float(self.batch_wait_s), - } - - def get_pending_count(self) -> int: - with self.condition: - return int(len(self.pending_tasks)) - - def snapshot(self) -> Dict[str, Any]: - with self.condition: - return { - "finalize_pending": int(len(self.pending_tasks)), - "finalize_pending_peak": int(self.pending_peak), - "finalize_inflight": int(self.inflight), - "finalize_inflight_peak": int(self.inflight_peak), - "finalize_workers": int(self.worker_count), - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), - } - - def is_idle(self) -> bool: - with self.condition: - return self.inflight <= 0 and not self.pending_tasks - - def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - if self.external_submit is not None: - self.external_submit(tasks) - self._notify_state_change() - return - with self.condition: - for task in tasks: - self.pending_tasks.append(task) - self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) - self.condition.notify_all() - self._notify_state_change() - - def begin_execution(self, task_count: int) -> None: - if task_count <= 0: - return - with self.condition: - self.inflight += int(task_count) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self.condition.notify_all() - self._notify_state_change() - - def end_execution(self, task_count: int) -> None: - with self.condition: - self.inflight = max(0, self.inflight - int(task_count)) - self.condition.notify_all() - self._notify_state_change() - - def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: - with self.condition: - while not self.pending_tasks: - self.condition.wait() - selected_tasks = [self.pending_tasks.popleft()] - if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - batch_deadline = time.perf_counter() + self.batch_wait_s - while len(selected_tasks) < self.batch_max_items: - if not self.pending_tasks: - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - continue - first_task = selected_tasks[0] - matched_index = None - for index, task in enumerate(self.pending_tasks): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is not None: - selected_tasks.append(self.pending_tasks[matched_index]) - del self.pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - audio_fragment = self.tts.synthesize_audio_request_local( - semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), - phones=job.state.phones.detach().clone().unsqueeze(0), - prompt_semantic=job.state.prompt_semantic.detach().clone(), - prompt_phones=job.state.prompt_phones.detach().clone(), - refer_spec=( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ), - raw_audio=job.state.raw_audio.detach().clone(), - raw_sr=int(job.state.raw_sr), - speed=float(job.speed_factor), - sample_steps=int(job.sample_steps), - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - return self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - - def _synthesize_finished_audio_batch( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> List[tuple[int, np.ndarray]]: - semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] - phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] - refer_specs = [] - speeds = [] - sample_steps_list = [] - for job, _ in jobs_and_items: - refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - ) - speeds.append(float(job.speed_factor)) - sample_steps_list.append(int(job.sample_steps)) - audio_fragments = self.tts.synthesize_audio_requests_local_batched( - semantic_tokens_list=semantic_tokens_list, - phones_list=phones_list, - refer_specs=refer_specs, - speeds=speeds, - sample_steps_list=sample_steps_list, - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - results: List[tuple[int, np.ndarray]] = [] - for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): - results.append( - self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - ) - return results - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - if not jobs_and_items: - return 0.0, [] - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - return float(synth_ms), batch_results - - -class WorkerCompletionBridge: - def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def notify_done_future(self, job: SchedulerPendingJob) -> None: - if job.done_loop is None or job.done_future is None: - return - try: - job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) - except RuntimeError: - pass - - def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.complete is None: - return - self.runtime_callbacks.complete(request_id, extra) - - def runtime_fail(self, request_id: str | None, error: str) -> None: - if request_id is None or self.runtime_callbacks.fail is None: - return - self.runtime_callbacks.fail(request_id, error) - - @staticmethod - def build_completed_job_result( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - finished_at: float | None = None, - ) -> Dict[str, Any]: - finished_at = float(time.perf_counter() if finished_at is None else finished_at) - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "decode_admission_wait_ms": float(job.admission_wait_ms), - "engine_policy_wait_ms": float(job.engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.result = result - return result - - @staticmethod - def build_runtime_complete_payload( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - ) -> Dict[str, Any]: - return { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "sample_rate": int(sample_rate), - "worker_profile": dict(job.result or {}), - } - - def complete_job( - self, - job: SchedulerPendingJob, - *, - runtime_request_id: str | None, - runtime_extra: Optional[Dict[str, Any]] = None, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_complete(runtime_request_id, runtime_extra) - - def fail_job( - self, - job: SchedulerPendingJob, - *, - error: str, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.error = str(error) - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_fail(job.engine_request_id, str(error)) - - def complete_finalize_task( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - job: SchedulerPendingJob, - item: T2SFinishedItem, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - runtime_extra: Optional[Dict[str, Any]] = None - with condition: - if job_registry.get(item.request_id) is not job: - return - self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) - self.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=runtime_extra, - on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), - notify_waiters=condition.notify_all, - ) - - def fail_jobs( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - request_ids: List[str], - error: str, - ) -> None: - if not request_ids: - return - with condition: - for request_id in request_ids: - job = job_registry.get(request_id) - if job is None: - continue - self.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), - ) - condition.notify_all() - - -class WorkerDecodeExecutor: - def __init__(self, tts: TTS, max_steps: int) -> None: - self.tts = tts - self.max_steps = int(max_steps) - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def execute_prefill_merge( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if not pending_jobs: - return { - "executed": False, - "active_batch": active_batch, - "pending_jobs": [], - "prefill_elapsed_s": 0.0, - "merge_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - admitted_finished: List[T2SFinishedItem] = [] - prefill_elapsed_s = 0.0 - merge_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - self._sync_device() - prefill_start = time.perf_counter() - mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - prefill_elapsed_s = time.perf_counter() - prefill_start - if add_prefill_time is not None: - add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(admitted_finished) - merge_start = time.perf_counter() - active_batch = merge_active_batches( - self.tts.t2s_model.model, - active_batch, - admitted_active_batch, - ) - merge_elapsed_s = time.perf_counter() - merge_start - if add_merge_time is not None: - add_merge_time( - [] if active_batch is None else list(active_batch.request_ids), - merge_elapsed_s, - ) - except Exception as exc: - error = str(exc) - error_request_ids = [job.request_id for job in pending_jobs] - if finalize_error is not None: - finalize_error(error_request_ids, error) - return { - "executed": True, - "active_batch": active_batch, - "pending_jobs": list(pending_jobs), - "prefill_elapsed_s": float(prefill_elapsed_s), - "merge_elapsed_s": float(merge_elapsed_s), - "finished_items": list(admitted_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_step( - self, - *, - active_batch: Optional[T2SActiveBatch], - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if active_batch is None: - return { - "executed": False, - "active_batch": None, - "request_ids": [], - "decode_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - active_request_ids: List[str] = [] - step_finished: List[T2SFinishedItem] = [] - decode_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - decode_elapsed_s = time.perf_counter() - decode_start - if add_decode_time is not None: - add_decode_time(active_request_ids, decode_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(step_finished) - except Exception as exc: - error = str(exc) - error_request_ids = list(active_request_ids) - if finalize_error is not None: - finalize_error(error_request_ids, error) - active_batch = None - return { - "executed": True, - "active_batch": active_batch, - "request_ids": active_request_ids, - "decode_elapsed_s": float(decode_elapsed_s), - "finished_items": list(step_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_cycle( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - result = { - "executed": False, - "prefill_merge_executed": False, - "decode_step_executed": False, - "active_batch": active_batch, - "prefill_phase": {}, - "decode_phase": {}, - } - prefill_phase = self.execute_prefill_merge( - pending_jobs=list(pending_jobs), - active_batch=result["active_batch"], - mark_prefill_started=mark_prefill_started, - add_prefill_time=add_prefill_time, - add_merge_time=add_merge_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - prefill_executed = bool(prefill_phase.get("executed", False)) - result["prefill_phase"] = prefill_phase - result["active_batch"] = prefill_phase.get("active_batch") - if prefill_executed: - result["executed"] = True - result["prefill_merge_executed"] = True - decode_phase = self.execute_decode_step( - active_batch=result["active_batch"], - add_decode_time=add_decode_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - decode_executed = bool(decode_phase.get("executed", False)) - result["decode_phase"] = decode_phase - result["active_batch"] = decode_phase.get("active_batch") - if decode_executed: - result["executed"] = True - result["decode_step_executed"] = True - return result - - -class WorkerDecodeLegacyShell: - def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: - self.condition = condition - self.micro_batch_wait_s = float(micro_batch_wait_s) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: - if active_batch is None: - return None - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": ( - int(active_batch.step_indices.max().item()) - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 - else 0 - ), - } - - def current_backlog_locked(self) -> int: - running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) - return int(len(self.pending_jobs) + running_requests) - - def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: - self.pending_jobs.append(job) - - def snapshot_locked(self) -> Dict[str, Any]: - active_batch_summary = self._summarize_active_batch(self.active_batch) - executor_local_pending_jobs = int(len(self.pending_jobs)) - executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) - executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) - return { - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "executor_local_active_batch": active_batch_summary, - } - - def is_idle_locked(self) -> bool: - return self.active_batch is None and not self.pending_jobs - - def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs and self.active_batch is None: - self.condition.wait(timeout=self.micro_batch_wait_s) - elif wait_for_batch and self.pending_jobs: - self.condition.wait(timeout=self.micro_batch_wait_s) - if not self.pending_jobs: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def has_decode_runtime_work(self) -> bool: - with self.condition: - return bool(self.pending_jobs or self.active_batch is not None) - - def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: - active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) - decode_step_index_max = 0 - prefill_done = False - if self.active_batch is not None: - prefill_done = bool(self.active_batch.prefill_done) - if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: - decode_step_index_max = int(self.active_batch.step_indices.max().item()) - return { - "pending_jobs": int(len(self.pending_jobs)), - "active_request_count": int(len(active_request_ids)), - "active_request_ids": active_request_ids[:32], - "prefill_done": bool(prefill_done), - "decode_step_index_max": int(decode_step_index_max), - "total_cycles": int(total_cycles), - "prefill_cycles": int(prefill_cycles), - "step_cycles": int(step_cycles), - "has_work": bool(self.pending_jobs or self.active_batch is not None), - "last_event": str(last_event), - "updated_at": float(time.perf_counter()), - } - - def run_prefill_merge_once_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_prefill_merge(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_step_once_nonblocking( - self, - *, - external_active_batch: Optional[T2SActiveBatch], - execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - active_batch = self.active_batch if external_active_batch is None else external_active_batch - result = execute_decode_step(active_batch) - if external_active_batch is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_cycle_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - on_cycle_executed: Callable[[Dict[str, Any]], None] | None, - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_decode_cycle(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - if result.get("executed") and on_cycle_executed is not None: - on_cycle_executed(result) - return result - - def run_loop( - self, - *, - run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], - ) -> None: - while True: - executed = run_decode_cycle_nonblocking() - if executed.get("executed"): - continue - wait_for_batch = self.active_batch is None - pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) - if pending_jobs: - with self.condition: - self.pending_jobs = pending_jobs + self.pending_jobs - self.condition.notify_all() - continue - time.sleep(self.micro_batch_wait_s) - - -class WorkerDecodeRuntimeTracker: - def __init__( - self, - runtime_callbacks: RuntimeStateCallbacks | None = None, - ) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.total_cycles = 0 - self.prefill_cycles = 0 - self.step_cycles = 0 - - def get_counters(self) -> Dict[str, int]: - return { - "total_cycles": int(self.total_cycles), - "prefill_cycles": int(self.prefill_cycles), - "step_cycles": int(self.step_cycles), - } - - def record_cycle(self, result: Dict[str, Any]) -> None: - if not bool(result.get("executed")): - return - self.total_cycles += 1 - if bool(result.get("prefill_merge_executed")): - self.prefill_cycles += 1 - if bool(result.get("decode_step_executed")): - self.step_cycles += 1 - - def build_runtime_summary_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> Dict[str, Any]: - return legacy_shell.build_runtime_summary_locked( - total_cycles=int(self.total_cycles), - prefill_cycles=int(self.prefill_cycles), - step_cycles=int(self.step_cycles), - last_event=str(last_event), - ) - - def notify_runtime_update_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> None: - if self.runtime_callbacks.decode_runtime_update is None: - return - snapshot = self.build_runtime_summary_locked( - legacy_shell=legacy_shell, - last_event=last_event, - ) - self.runtime_callbacks.decode_runtime_update(snapshot) - - -class UnifiedSchedulerWorker: +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_completion import WorkerCompletionBridge +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_decode import WorkerDecodeExecutor, WorkerDecodeLegacyShell, WorkerDecodeRuntimeTracker +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_execution import WorkerExecutionMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_finalize import WorkerFinalizeExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_prepare import WorkerPrepareExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_runtime import WorkerRuntimeBookkeepingMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_submit import WorkerSubmitLifecycleMixin + + +class UnifiedSchedulerWorker( + WorkerSubmitLifecycleMixin, + WorkerRuntimeBookkeepingMixin, + WorkerExecutionMixin, +): def __init__( self, tts: TTS, @@ -949,562 +69,3 @@ class UnifiedSchedulerWorker: def _notify_worker_state_change(self) -> None: with self.condition: self.condition.notify_all() - - def _current_decode_backlog_locked(self) -> int: - return self.decode_legacy_shell.current_backlog_locked() - - def get_micro_batch_wait_s(self) -> float: - return float(self.micro_batch_wait_s) - - def is_engine_decode_control_enabled(self) -> bool: - return bool(self.engine_decode_control_enabled) - - def get_prepare_max_inflight(self) -> int: - return int(self.prepare_executor.get_max_inflight()) - - def get_capacity_limits(self) -> Dict[str, int]: - return { - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - } - - def get_finalize_batch_policy(self) -> Dict[str, Any]: - return dict(self.finalize_executor.get_batch_policy()) - - def get_decode_runtime_counters(self) -> Dict[str, int]: - with self.condition: - return self.decode_runtime_tracker.get_counters() - - def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: - decode_backlog = self._current_decode_backlog_locked() - finalize_pending = int(self.finalize_executor.get_pending_count()) - prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) - blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max - blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max - return ( - not blocked_decode and not blocked_finalize, - { - "decode_backlog": decode_backlog, - "finalize_pending": finalize_pending, - "prepare_inflight": prepare_inflight, - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - }, - ) - - def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - last_snapshot: Dict[str, int] = {} - while True: - with self.condition: - allowed, snapshot = self._can_accept_submit_locked() - last_snapshot = snapshot - if allowed: - return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot - if deadline is not None and time.perf_counter() >= deadline: - raise TimeoutError( - "scheduler submit admission timeout " - f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" - ) - self.condition.wait(timeout=self.micro_batch_wait_s) - - def _admission_snapshot_locked(self) -> Dict[str, int]: - _, snapshot = self._can_accept_submit_locked() - return snapshot - - async def submit_async( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - return await asyncio.to_thread( - self.submit, - state, - speed_factor, - sample_steps, - media_type, - prepare_wall_ms, - prepare_profile_total_ms, - done_loop, - done_future, - engine_request_id, - timeout_sec, - skip_capacity_wait, - admission_wait_ms_override, - admission_snapshot_override, - engine_policy_wait_ms, - engine_dispatch_wait_ms, - enqueue_pending, - ) - - def snapshot(self) -> dict: - with self.condition: - prepare_state = self.prepare_executor.snapshot() - finalize_state = self.finalize_executor.snapshot() - shell_state = self.decode_legacy_shell.snapshot_locked() - decode_runtime_counters = self.decode_runtime_tracker.get_counters() - engine_owned_decode_state = bool(self.engine_decode_control_enabled) - active_batch_summary = shell_state.get("executor_local_active_batch") - executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) - executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) - executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) - return { - "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, - "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, - "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), - "legacy_state_owner_mode": not engine_owned_decode_state, - "decode_state_owner": "engine" if engine_owned_decode_state else "worker", - "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), - "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), - "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), - "prepare_inflight": prepare_state["inflight"], - "prepare_peak_inflight": prepare_state["peak_inflight"], - "prepare_max_inflight": prepare_state.get("max_inflight", 0), - "prepare_state": dict(prepare_state), - **finalize_state, - "decode_backlog_max": self.decode_backlog_max, - "finalize_pending_max": self.finalize_pending_max, - "active_batch": {} if engine_owned_decode_state else active_batch_summary, - "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, - "total_submitted": self.job_registry.submitted_count(), - "total_finished": self.job_registry.finished_count(), - "drained": self.is_drained(), - } - - def is_drained(self) -> bool: - with self.condition: - return ( - self.decode_legacy_shell.is_idle_locked() - and self.job_registry.is_empty() - and self.prepare_executor.is_idle() - and self.finalize_executor.is_idle() - ) - - def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: - deadline = time.perf_counter() + max(0.0, timeout_sec) - while time.perf_counter() < deadline: - if self.is_drained(): - return True - time.sleep(poll_interval_sec) - return self.is_drained() - - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - if skip_capacity_wait: - with self.condition: - admission_snapshot = ( - dict(admission_snapshot_override) - if admission_snapshot_override is not None - else dict(self._admission_snapshot_locked()) - ) - admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) - else: - admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - admission_wait_ms=float(admission_wait_ms), - engine_policy_wait_ms=float(engine_policy_wait_ms), - engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - engine_request_id=engine_request_id or state.request_id, - ) - with self.condition: - self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) - if enqueue_pending: - self.decode_legacy_shell.enqueue_pending_job_locked(job) - self.condition.notify_all() - if enqueue_pending: - self._notify_decode_runtime_state("submit") - self._runtime_update( - job.engine_request_id, - EngineStatus.QUEUED, - { - "scheduler_request_id": job.request_id, - "decode_admission_wait_ms": float(admission_wait_ms), - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), - "admission_snapshot": dict(admission_snapshot), - }, - ) - return job - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return await self.prepare_executor.prepare_states_batch_async(specs) - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) - - def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in pending_jobs: - job.first_schedule_time = float(started_at) - self._runtime_update( - job.engine_request_id, - EngineStatus.GPU_PREPARING, - {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, - ) - - def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.prefill_ms += delta_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - activate_request_ids: List[str] = [] - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.finalize_wait_ms += float(delta_ms) - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks: List[SchedulerFinalizeTask] = [] - with self.condition: - for item in items: - job = self.job_registry.get(item.request_id) - if job is not None: - self._runtime_update( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - }, - ) - tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) - self.finalize_executor.enqueue_tasks(tasks) - - def begin_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.begin_execution(task_count) - - def end_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.end_execution(task_count) - - def record_external_job_done(self, request_id: str) -> None: - with self.condition: - self.job_registry.mark_finished_and_remove(request_id) - self.condition.notify_all() - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) - - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - self.completion_bridge.complete_finalize_task( - condition=self.condition, - job_registry=self.job_registry, - job=job, - item=item, - sample_rate=sample_rate, - audio_data=audio_data, - ) - - def _finalize_error(self, request_ids: List[str], error: str) -> None: - self.completion_bridge.fail_jobs( - condition=self.condition, - job_registry=self.job_registry, - request_ids=request_ids, - error=error, - ) - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def _notify_done_future(self, job: SchedulerPendingJob) -> None: - self.completion_bridge.notify_done_future(job) - - def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.update is None: - return - self.runtime_callbacks.update(request_id, status, extra) - - def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - self.completion_bridge.runtime_complete(request_id, extra) - - def _runtime_fail(self, request_id: str | None, error: str) -> None: - self.completion_bridge.runtime_fail(request_id, error) - - def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: - return self.decode_runtime_tracker.build_runtime_summary_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _notify_decode_runtime_state(self, last_event: str) -> None: - with self.condition: - self.decode_runtime_tracker.notify_runtime_update_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: - with self.condition: - self.decode_runtime_tracker.record_cycle(result) - - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) - - def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) - - def has_decode_runtime_work(self) -> bool: - return self.decode_legacy_shell.has_decode_runtime_work() - - def execute_prefill_merge( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_prefill_merge( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_step( - self, - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_decode_step( - active_batch=active_batch, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_cycle( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_executor.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - self._record_decode_runtime_cycle(result) - return result - - def run_prefill_merge_once_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("prefill_merge") - return result - - def run_decode_step_once_nonblocking( - self, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_step_once_nonblocking( - external_active_batch=external_active_batch, - execute_decode_step=lambda batch_state: self.execute_decode_step( - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("decode_step") - return result - - def run_decode_cycle_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_cycle_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - on_cycle_executed=None, - ) - if result.get("executed") and emit_runtime_state: - self._notify_decode_runtime_state("decode_cycle") - return result - - def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - with self.condition: - for task in tasks: - job = self.job_registry.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: - self._runtime_update( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_registry.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self.finalize_executor.end_execution(len(tasks)) - - def _run_finalize_loop(self) -> None: - while True: - tasks = self.finalize_executor.take_task_batch_blocking() - self.execute_finalize_tasks(tasks) - - def _run_loop(self) -> None: - self.decode_legacy_shell.run_loop( - run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() - ) - - diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py new file mode 100644 index 00000000..da2c057a --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob + + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py new file mode 100644 index 00000000..784f71d0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + T2SActiveBatch, + T2SFinishedItem, + decode_one_step, + merge_active_batches, + run_prefill_active_batch, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: + while True: + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: + with self.condition: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py new file mode 100644 index 00000000..465f7a2c --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerExecutionMixin: + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) + + def _run_loop(self) -> None: + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py new file mode 100644 index 00000000..4f5833fd --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import os +import threading +import time +from collections import deque +from typing import Any, Callable, Deque, Dict, List + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: + with self.condition: + return { + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), + } + + def is_idle(self) -> bool: + with self.condition: + return self.inflight <= 0 and not self.pending_tasks + + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py new file mode 100644 index 00000000..28da24ee --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Callable, Dict, List + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState + + +class WorkerPrepareExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) + + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) + + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py new file mode 100644 index 00000000..de12f5e1 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerRuntimeBookkeepingMixin: + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py new file mode 100644 index 00000000..1e67f8d3 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob + + +class WorkerSubmitLifecycleMixin: + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)