diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
index 1fdf95c5..306b1b45 100644
--- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
+++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
@@ -33,6 +33,20 @@ class ProfiledResult:
         return max(0.0, (self.finished_at - self.started_at) * 1000.0)
 
 
+@dataclass
+class PreparedCpuStage:
+    spec: SchedulerRequestSpec
+    prepare_submit_at: float
+    prepare_start: float
+    prompt_text: str
+    text: str
+    prepare_admission_wait_ms: float
+    current_inflight: int
+    peak_inflight: int
+    prompt_cpu_profiled: ProfiledResult
+    target_cpu_profiled: ProfiledResult
+
+
 class PrepareCoordinator:
     def __init__(self, tts: Any):
         self.tts = tts
@@ -216,11 +230,16 @@ class PrepareCoordinator:
     async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
         return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
 
-    async def prepare_state_profiled_async(
+    def _release_split_stage_slot(self) -> None:
+        self._mark_leave()
+        if self._inflight_semaphore is not None:
+            self._inflight_semaphore.release()
+
+    async def prepare_cpu_stage_profiled_async(
         self,
         spec: SchedulerRequestSpec,
         prepare_submit_at: float,
-    ) -> tuple[T2SRequestState, float, float]:
+    ) -> PreparedCpuStage:
         admission_start = time.perf_counter()
         if self._inflight_semaphore is not None:
             await self._inflight_semaphore.acquire()
@@ -230,17 +249,38 @@ class PrepareCoordinator:
         prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
         text = spec.text.strip("\n")
         try:
-            text_pair_start = time.perf_counter()
             prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang))
             target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang))
-            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path)))
             prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task)
+            return PreparedCpuStage(
+                spec=spec,
+                prepare_submit_at=float(prepare_submit_at),
+                prepare_start=float(prepare_start),
+                prompt_text=prompt_text,
+                text=text,
+                prepare_admission_wait_ms=float(prepare_admission_wait_ms),
+                current_inflight=int(current_inflight),
+                peak_inflight=int(peak_inflight),
+                prompt_cpu_profiled=prompt_cpu_profiled,
+                target_cpu_profiled=target_cpu_profiled,
+            )
+        except Exception:
+            self._release_split_stage_slot()
+            raise
+
+    async def prepare_gpu_stage_profiled_async(
+        self,
+        cpu_stage: PreparedCpuStage,
+    ) -> tuple[T2SRequestState, float, float]:
+        try:
+            text_pair_start = time.perf_counter()
+            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
             text_feature_pair_task = asyncio.create_task(
                 self._run_text_feature_pair_stage(
-                    prompt_cpu_profiled.result,
-                    target_cpu_profiled.result,
-                    prompt_cpu_profiled.run_ms,
-                    target_cpu_profiled.run_ms,
+                    cpu_stage.prompt_cpu_profiled.result,
+                    cpu_stage.target_cpu_profiled.result,
+                    cpu_stage.prompt_cpu_profiled.run_ms,
+                    cpu_stage.target_cpu_profiled.run_ms,
                 )
             )
             (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather(
@@ -250,18 +290,18 @@
             text_pair_end = time.perf_counter()
             state = build_request_state_from_parts(
                 tts=self.tts,
-                spec=spec,
-                prompt_text=prompt_text,
-                text=text,
+                spec=cpu_stage.spec,
+                prompt_text=cpu_stage.prompt_text,
+                text=cpu_stage.text,
                 prompt_result=prompt_feature_profiled.result,
                 target_result=target_feature_profiled.result,
                 ref_audio_bundle=ref_audio_profiled.result,
-                prepare_start=prepare_start,
-                prepare_sync_start=prepare_start,
+                prepare_start=cpu_stage.prepare_start,
+                prepare_sync_start=cpu_stage.prepare_start,
                 profile_overrides={
-                    "executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0),
-                    "prepare_admission_wait_ms": prepare_admission_wait_ms,
-                    "executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0),
+                    "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
+                    "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
+                    "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
                     "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
                     "prompt_text_parallel_future_wait_ms": 0.0,
                     "prompt_text_parallel_future_executor_queue_ms": 0.0,
@@ -269,26 +309,32 @@
                     "prompt_text_parallel_future_finish_after_submit_ms": 0.0,
                     "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
                     "prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
-                    "prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms,
-                    "prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms,
+                    "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
+                    "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
                     "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
                     "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
-                    "text_cpu_queue_ms": target_cpu_profiled.queue_ms,
-                    "text_cpu_run_ms": target_cpu_profiled.run_ms,
+                    "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
+                    "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
                     "text_feature_queue_ms": target_feature_profiled.queue_ms,
                     "text_feature_run_ms": target_feature_profiled.run_ms,
                     "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
                     "ref_audio_task_run_ms": ref_audio_profiled.run_ms,
-                    "worker_prepare_inflight_on_enter": float(current_inflight),
-                    "worker_prepare_peak_inflight": float(peak_inflight),
+                    "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
+                    "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
                 },
             )
             prepare_exec_finished_at = time.perf_counter()
             state.prepare_profile["executor_run_wall_ms"] = max(
-                0.0, (prepare_exec_finished_at - prepare_start) * 1000.0
+                0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0
             )
-            return state, prepare_start, prepare_exec_finished_at
+            return state, cpu_stage.prepare_start, prepare_exec_finished_at
         finally:
-            self._mark_leave()
-            if self._inflight_semaphore is not None:
-                self._inflight_semaphore.release()
+            self._release_split_stage_slot()
+
+    async def prepare_state_profiled_async(
+        self,
+        spec: SchedulerRequestSpec,
+        prepare_submit_at: float,
+    ) -> tuple[T2SRequestState, float, float]:
+        cpu_stage = await self.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
+        return await self.prepare_gpu_stage_profiled_async(cpu_stage)
diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py
index aed7b146..9b56199a 100644
--- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py
+++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py
@@ -20,7 +20,7 @@ import soundfile as sf
 import torch
 
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS
-from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator
+from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
 from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
     SchedulerRequestSpec,
     T2SActiveBatch,
@@ -235,6 +235,46 @@ class SchedulerSubmitExecution:
     headers: Dict[str, str]
 
 
+@dataclass
+class EnginePolicyConfig:
+    enabled: bool = True
+    poll_wait_ms: float = 5.0
+    decode_backlog_soft_max: int = 0
+    finalize_pending_soft_max: int = 0
+    prepare_inflight_soft_max: int = 0
+    active_decode_soft_max: int = 0
+    ready_for_prefill_soft_max: int = 0
+    active_request_soft_max: int = 0
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "enabled": bool(self.enabled),
+            "poll_wait_ms": float(self.poll_wait_ms),
+            "decode_backlog_soft_max": int(self.decode_backlog_soft_max),
+            "finalize_pending_soft_max": int(self.finalize_pending_soft_max),
+            "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max),
+            "active_decode_soft_max": int(self.active_decode_soft_max),
+            "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max),
+            "active_request_soft_max": int(self.active_request_soft_max),
+        }
+
+
+@dataclass
+class EngineArbiterConfig:
+    poll_wait_ms: float = 5.0
+    decode_burst: int = 4
+    prepare_aging_ms: float = 10.0
+    finalize_aging_ms: float = 10.0
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "poll_wait_ms": float(self.poll_wait_ms),
+            "decode_burst": int(self.decode_burst),
+            "prepare_aging_ms": float(self.prepare_aging_ms),
+            "finalize_aging_ms": float(self.finalize_aging_ms),
+        }
+
+
 class EngineStatus:
     NEW = "NEW"
     QUEUED = "QUEUED"
@@ -286,6 +326,146 @@ class EngineRequestState:
         }
 
 
+class EngineRequestRegistry:
+    def __init__(self, recent_limit: int) -> None:
+        self.lock = threading.Lock()
+        self.active_requests: Dict[str, EngineRequestState] = {}
+        self.recent_requests: Deque[EngineRequestState] = deque()
+        self.recent_limit = max(1, int(recent_limit))
+
+    def register(
+        self,
+        *,
+        request_id: str,
+        api_mode: str,
+        backend: str,
+        media_type: str,
+        response_streaming: bool,
+        deadline_ts: float | None = None,
+        meta: Optional[Dict[str, Any]] = None,
+    ) -> EngineRequestState:
+        now = time.perf_counter()
+        state = EngineRequestState(
+            request_id=request_id,
+            api_mode=api_mode,
+            backend=backend,
+            media_type=media_type,
+            response_streaming=bool(response_streaming),
+            submit_ts=now,
+            deadline_ts=deadline_ts,
+            updated_ts=now,
+            meta=dict(meta or {}),
+            lifecycle_timestamps={EngineStatus.NEW: now},
+        )
+        with self.lock:
+            self.active_requests[request_id] = state
+        return state
+
+    def _move_to_recent_locked(self, state: EngineRequestState) -> None:
+        self.recent_requests.appendleft(state)
+        while len(self.recent_requests) > self.recent_limit:
+            self.recent_requests.pop()
+
+    @staticmethod
+    def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None:
+        if not extra:
+            return
+        payload = dict(extra)
+        backend = payload.pop("backend", None)
+        if backend is not None:
+            state.backend = str(backend)
+        finish_reason = payload.pop("finish_reason", None)
+        if finish_reason is not None:
+            state.finish_reason = str(finish_reason)
+        error = payload.pop("error", None)
+        if error is not None:
+            state.error = str(error)
+        state.profile.update(payload)
+
+    def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
+        now = time.perf_counter()
+        with self.lock:
+            state = self.active_requests.get(request_id)
+            if state is None:
+                return
+            state.status = str(status)
+            state.updated_ts = now
+            state.lifecycle_timestamps[str(status)] = now
+            self._apply_state_extra(state, extra)
+
+    def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
+        if not extra:
+            return
+        now = time.perf_counter()
+        with self.lock:
+            state = self.active_requests.get(request_id)
+            if state is None:
+                for recent_state in self.recent_requests:
+                    if recent_state.request_id == request_id:
+                        state = recent_state
+                        break
+            if state is None:
+                return
+            state.updated_ts = now
+            self._apply_state_extra(state, extra)
+
+    def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
+        now = time.perf_counter()
+        with self.lock:
+            state = self.active_requests.pop(request_id, None)
+            if state is None:
+                return
+            state.status = EngineStatus.COMPLETED
+            state.updated_ts = now
+            state.lifecycle_timestamps[EngineStatus.COMPLETED] = now
+            self._apply_state_extra(state, extra)
+            self._move_to_recent_locked(state)
+
+    def fail(self, request_id: str, error: str) -> None:
+        now = time.perf_counter()
+        with self.lock:
+            state = self.active_requests.pop(request_id, None)
+            if state is None:
+                return
+            state.status = EngineStatus.FAILED
+            state.updated_ts = now
+            state.error = str(error)
+            state.lifecycle_timestamps[EngineStatus.FAILED] = now
+            self._move_to_recent_locked(state)
+
+    def snapshot(self) -> Dict[str, Any]:
+        with self.lock:
+            active = [state.to_summary() for state in self.active_requests.values()]
+            recent = [state.to_summary() for state in list(self.recent_requests)]
+            recent_limit = self.recent_limit
+        active.sort(key=lambda item: item["submit_ts"])
+        return {
+            "active_count": len(active),
+            "recent_count": len(recent),
+            "recent_limit": recent_limit,
+            "active_requests": active,
+            "recent_requests": recent,
+        }
+
+    def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
+        requested = set(request_ids)
+        results: List[Dict[str, Any]] = []
+        with self.lock:
+            for state in self.active_requests.values():
+                if state.request_id in requested:
+                    results.append(state.to_summary())
+            existing_ids = {item["request_id"] for item in results}
+            for state in self.recent_requests:
+                if state.request_id in requested and state.request_id not in existing_ids:
+                    results.append(state.to_summary())
+        results.sort(key=lambda item: item["request_id"])
+        return results
+
+    def has_active(self, request_id: str) -> bool:
+        with self.lock:
+            return request_id in self.active_requests
+
+
 @dataclass
 class SchedulerPendingJob:
     request_id: str
@@ -298,6 +478,8 @@ class SchedulerPendingJob:
     sample_steps: int
     media_type: str
     admission_wait_ms: float = 0.0
+    engine_policy_wait_ms: float = 0.0
+    engine_dispatch_wait_ms: float = 0.0
     prepare_wall_ms: float = 0.0
     prepare_profile_total_ms: float = 0.0
     first_schedule_time: float | None = None
@@ -316,6 +498,584 @@ class SchedulerPendingJob:
     engine_request_id: str | None = None
 
 
+class SchedulerJobRegistry:
+    def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None:
+        self._lock = lock
+        self._job_map: Dict[str, SchedulerPendingJob] = {}
+        self._total_submitted = 0
+        self._total_finished = 0
+
+    def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None:
+        with self._lock:
+            if keep_job:
+                self._job_map[job.request_id] = job
+            self._total_submitted += 1
+
+    def get(self, request_id: str) -> SchedulerPendingJob | None:
+        with self._lock:
+            return self._job_map.get(request_id)
+
+    def pop(self, request_id: str) -> SchedulerPendingJob | None:
+        with self._lock:
+            return self._job_map.pop(request_id, None)
+
+    def remove(self, request_id: str) -> None:
+        with self._lock:
+            self._job_map.pop(request_id, None)
+
+    def mark_finished(self) -> None:
+        with self._lock:
+            self._total_finished += 1
+
+    def mark_finished_and_remove(self, request_id: str) -> None:
+        with self._lock:
+            self._job_map.pop(request_id, None)
+            self._total_finished += 1
+
+    def is_empty(self) -> bool:
+        with self._lock:
+            return not self._job_map
+
+    def submitted_count(self) -> int:
+        with self._lock:
+            return int(self._total_submitted)
+
+    def finished_count(self) -> int:
+        with self._lock:
+            return int(self._total_finished)
+
+    def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]:
+        with self._lock:
+            request_ids = list(self._job_map.keys())
+            return {
+                "job_count": int(len(request_ids)),
+                "request_ids": request_ids[: max(0, int(max_request_ids))],
+                "total_submitted": int(self._total_submitted),
+                "total_finished": int(self._total_finished),
+            }
+
+
+class EngineTaskQueueOwner:
+    def __init__(self, completion_key: str = "total_completed") -> None:
+        self.condition = threading.Condition()
+        self.queue: Deque[Any] = deque()
+        self.total_submitted = 0
+        self.total_completed = 0
+        self.peak_waiting = 0
+        self.completion_key = str(completion_key)
+
+    def enqueue(self, item: Any) -> None:
+        with self.condition:
+            self.queue.append(item)
+            self.total_submitted += 1
+            self.peak_waiting = max(self.peak_waiting, len(self.queue))
+            self.condition.notify_all()
+
+    def enqueue_many(self, items: Sequence[Any]) -> None:
+        if not items:
+            return
+        with self.condition:
+            for item in items:
+                self.queue.append(item)
+            self.total_submitted += len(items)
+            self.peak_waiting = max(self.peak_waiting, len(self.queue))
+            self.condition.notify_all()
+
+    def pop_left(self) -> Any | None:
+        with self.condition:
+            if not self.queue:
+                return None
+            return self.queue.popleft()
+
+    def mark_completed(self, count: int = 1, *, notify: bool = False) -> None:
+        if count <= 0:
+            return
+        with self.condition:
+            self.total_completed += int(count)
+            if notify:
+                self.condition.notify_all()
+
+    def has_items(self) -> bool:
+        with self.condition:
+            return bool(self.queue)
+
+    def waiting_count(self) -> int:
+        with self.condition:
+            return int(len(self.queue))
+
+    def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        with self.condition:
+            waiting_items = list(self.queue)[: max(0, int(max_request_ids))]
+            snapshot = {
+                "waiting_count": int(len(self.queue)),
+                "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items],
+                "peak_waiting": int(self.peak_waiting),
+                "total_submitted": int(self.total_submitted),
+                self.completion_key: int(self.total_completed),
+            }
+            if extra:
+                snapshot.update(dict(extra))
+            return snapshot
+
+    def peek_oldest_age_ms(self, timestamp_attr: str) -> float:
+        with self.condition:
+            if not self.queue:
+                return 0.0
+            enqueue_time = float(getattr(self.queue[0], timestamp_attr))
+            return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
+
+    def is_drained(self) -> bool:
+        with self.condition:
+            return not self.queue and self.total_submitted == self.total_completed
+
+    def take_finalize_batch(
+        self,
+        *,
+        finalize_mode: str,
+        batch_max_items: int,
+        batch_wait_s: float,
+        use_vocoder: bool,
+    ) -> List[SchedulerFinalizeTask]:
+        with self.condition:
+            if not self.queue:
+                return []
+            selected_tasks = [self.queue.popleft()]
+            if finalize_mode == "sync" or use_vocoder:
+                return selected_tasks
+            if batch_max_items <= 1:
+                return selected_tasks
+            first_task = selected_tasks[0]
+            oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time)
+            if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s:
+                self.queue.appendleft(first_task)
+                return []
+            while len(selected_tasks) < batch_max_items:
+                if not self.queue:
+                    break
+                matched_index = None
+                for index, task in enumerate(self.queue):
+                    if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
+                        matched_index = index
+                        break
+                if matched_index is None:
+                    break
+                selected_tasks.append(self.queue[matched_index])
+                del self.queue[matched_index]
+            return selected_tasks
+
+
+class EnginePolicyArbiterController:
+    def __init__(
+        self,
+        *,
+        policy_config: EnginePolicyConfig,
+        arbiter_config: EngineArbiterConfig,
+        snapshot_request_registry: Callable[[], Dict[str, Any]],
+        get_worker_state: Callable[[], Dict[str, Any]],
+        snapshot_prepare_state: Callable[[], Dict[str, Any]],
+        snapshot_finalize_state: Callable[[], Dict[str, Any]],
+        snapshot_dispatch_state: Callable[[], Dict[str, Any]],
+        snapshot_decode_runtime_state: Callable[[], Dict[str, Any]],
+        snapshot_job_registry: Callable[[], Dict[str, Any]],
+        peek_queue_age_ms: Callable[[str], float],
+        merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
+    ) -> None:
+        self.policy_config = policy_config
+        self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0)
+        self.arbiter_config = arbiter_config
+        self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0)
+        self.condition = threading.Condition()
+        self.state = EngineArbiterState(
+            decode_budget_remaining=int(self.arbiter_config.decode_burst),
+            last_observed_at=time.perf_counter(),
+        )
+        self.snapshot_request_registry = snapshot_request_registry
+        self.get_worker_state = get_worker_state
+        self.snapshot_prepare_state = snapshot_prepare_state
+        self.snapshot_finalize_state = snapshot_finalize_state
+        self.snapshot_dispatch_state = snapshot_dispatch_state
+        self.snapshot_decode_runtime_state = snapshot_decode_runtime_state
+        self.snapshot_job_registry = snapshot_job_registry
+        self.peek_queue_age_ms = peek_queue_age_ms
+        self.merge_request_state_profile = merge_request_state_profile
+
+    def snapshot_state(self) -> Dict[str, Any]:
+        with self.condition:
+            return {
+                "config": self.arbiter_config.to_dict(),
+                "total_ticks": int(self.state.total_ticks),
+                "total_idle_ticks": int(self.state.total_idle_ticks),
+                "total_prepare_dispatches": int(self.state.total_prepare_dispatches),
+                "total_decode_dispatches": int(self.state.total_decode_dispatches),
+                "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks),
+                "total_finalize_dispatches": int(self.state.total_finalize_dispatches),
+                "decode_budget_remaining": int(self.state.decode_budget_remaining),
+                "last_stage": str(self.state.last_stage),
+                "last_reason": str(self.state.last_reason),
+                "last_policy_allowed": bool(self.state.last_policy_allowed),
+                "last_observed_at": float(self.state.last_observed_at),
+            }
+
+    def notify(self) -> None:
+        with self.condition:
+            self.condition.notify_all()
+
+    def wait(self) -> None:
+        with self.condition:
+            self.condition.wait(timeout=self.arbiter_poll_s)
+
+    def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
+        with self.condition:
+            self.state.total_ticks += 1
+            if stage == "idle":
+                self.state.total_idle_ticks += 1
+            elif stage == "prepare":
+                self.state.total_prepare_dispatches += 1
+                self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
+            elif stage == "finalize":
+                self.state.total_finalize_dispatches += 1
+                self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
+            elif stage == "decode_dispatch":
+                self.state.total_decode_dispatches += 1
+            elif stage == "decode_runtime":
+                self.state.total_decode_runtime_ticks += 1
+                self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1)
+            self.state.last_stage = str(stage)
+            self.state.last_reason = str(reason)
+            self.state.last_policy_allowed = bool(policy_allowed)
+            self.state.last_observed_at = time.perf_counter()
+
+    def build_stage_counters(
+        self,
+        request_registry: Dict[str, Any],
+        worker_state: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        prepare_dispatcher_state = self.snapshot_prepare_state()
+        finalize_dispatcher_state = self.snapshot_finalize_state()
+        dispatcher_state = self.snapshot_dispatch_state()
+        active_requests = list(request_registry.get("active_requests", []))
+        status_counts: Dict[str, int] = {}
+        for item in active_requests:
+            status = str(item.get("status", "UNKNOWN"))
+            status_counts[status] = status_counts.get(status, 0) + 1
+
+        worker_pending_jobs = int(worker_state.get("pending_jobs", 0))
+        worker_decode_active_size = int(worker_state.get("running_requests", 0))
+        worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0))
+        worker_finalize_pending = int(worker_state.get("finalize_pending", 0))
+        worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0))
+        engine_decode_runtime_state = self.snapshot_decode_runtime_state()
+        engine_job_registry = self.snapshot_job_registry()
+        decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0))
+        decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0))
+        return {
+            "active_request_count": int(len(active_requests)),
+            "status_counts": status_counts,
+            "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)),
+            "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)),
+            "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)),
+            "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)),
+            "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)),
+            "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)),
+            "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)),
+            "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)),
+            "worker_pending_jobs": worker_pending_jobs,
+            "worker_decode_active_size": worker_decode_active_size,
+            "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)),
+            "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)),
+            "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs,
+            "engine_decode_runtime_active_request_count": decode_runtime_active_size,
+            "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)),
+            "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)),
+            "worker_prepare_inflight": worker_prepare_inflight,
+            "worker_finalize_pending": worker_finalize_pending,
+            "worker_finalize_inflight": worker_finalize_inflight,
+            "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)),
+            "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)),
+            "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)),
+            "decode_backlog": int(
+                decode_runtime_pending_jobs + decode_runtime_active_size
+                if bool(worker_state.get("engine_decode_control_enabled", False))
+                else worker_pending_jobs + worker_decode_active_size
+            ),
+        }
+
+    def build_policy_snapshot(
+        self,
+        request_registry: Dict[str, Any],
+        worker_state: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        counters = self.build_stage_counters(request_registry, worker_state)
+        config = self.policy_config.to_dict()
+        blocked_reasons: List[Dict[str, Any]] = []
+        finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0))
+        limit_checks = [
+            ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])),
+            ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])),
+            ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])),
+            ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])),
+            ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])),
+            ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])),
+        ]
+        if bool(config["enabled"]):
+            for name, value, limit in limit_checks:
+                if limit > 0 and int(value) >= int(limit):
+                    blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)})
+        return {
+            "enabled": bool(config["enabled"]),
+            "allowed": (not bool(config["enabled"])) or not blocked_reasons,
+            "blocked_reasons": blocked_reasons,
+            "config": config,
+            "metrics": {
+                "active_request_count": int(counters["active_request_count"]),
+                "queued_request_count": int(counters["queued_request_count"]),
+                "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]),
+                "active_decode_request_count": int(counters["active_decode_request_count"]),
+                "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]),
+                "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]),
+                "decode_backlog": int(counters["decode_backlog"]),
+                "prepare_inflight": int(counters["worker_prepare_inflight"]),
+                "finalize_pending": int(finalize_pending_total),
+                "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)),
+                "finalize_inflight": int(counters["worker_finalize_inflight"]),
+            },
+            "observed_at": time.perf_counter(),
+        }
+
+    async def wait_for_policy_admission(
+        self,
+        *,
+        request_id: str | None,
+        timeout_sec: float | None,
+    ) -> tuple[float, Dict[str, Any]]:
+        request_registry = self.snapshot_request_registry()
+        worker_state = self.get_worker_state()
+        snapshot = self.build_policy_snapshot(request_registry, worker_state)
+        if not self.policy_config.enabled:
+            return 0.0, snapshot
+        start = time.perf_counter()
+        deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
+        while True:
+            request_registry = self.snapshot_request_registry()
+            worker_state = self.get_worker_state()
+            snapshot = self.build_policy_snapshot(request_registry, worker_state)
+            if snapshot["allowed"]:
+                wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0)
+                if request_id not in [None, ""]:
+                    self.merge_request_state_profile(
+                        str(request_id),
+                        {
+                            "engine_policy_wait_ms": float(wait_ms),
+                            "engine_policy_snapshot": snapshot,
+                        },
+                    )
+                return wait_ms, snapshot
+            now = time.perf_counter()
+            if deadline is not None and now >= deadline:
+                blocked_summary = ", ".join(
+                    f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", [])
+                )
+                raise TimeoutError(f"engine policy admission timeout ({blocked_summary})")
+            await asyncio.sleep(self.policy_poll_s)
+
+    def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
+        request_registry = self.snapshot_request_registry()
+        worker_state = self.get_worker_state()
+        policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
+        prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0))
+        finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
+        decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
+        decode_runtime_state = self.snapshot_decode_runtime_state()
+        worker_decode_has_work = bool(decode_runtime_state.get("has_work", False))
+        worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
+        worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
+        worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
+        prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
+        finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
+        decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
+        decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
+        policy_allowed = bool(policy_snapshot.get("allowed", True))
+        if (
+            worker_decode_control_enabled
+            and worker_decode_has_work
+            and policy_allowed
+            and decode_budget_remaining > 0
+            and (worker_running_requests > 0 or worker_pending_jobs > 0)
+        ):
+            return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state
+        if (
+            worker_decode_control_enabled
+            and worker_pending_jobs > 0
+            and policy_allowed
+            and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms)
+        ):
+            return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state
+        if (
+            decode_waiting > 0
+            and policy_allowed
+            and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
+        ):
+            return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
+        if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
+            return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
+        if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
+            return "finalize", "finalize_aging", policy_snapshot, worker_state
+        if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
+            return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
+        if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
+            return "prepare", "prepare_aging", policy_snapshot, worker_state
+        if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
+            return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
+        if decode_waiting > 0 and policy_allowed:
+            return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state
+        if finalize_waiting > 0:
+            return "finalize", "finalize_fallback", policy_snapshot, worker_state
+        if prepare_waiting > 0:
+            return "prepare", "prepare_fallback", policy_snapshot, worker_state
+        return "idle", "no_pending_work", policy_snapshot, worker_state
+
+
+class EngineDecodeRuntimeOwner:
+    def __init__(
+        self,
+        *,
+        get_decode_runtime_counters: Callable[[], Dict[str, int]],
+        get_micro_batch_wait_s: Callable[[], float],
+    ) -> None:
+        self.get_decode_runtime_counters = get_decode_runtime_counters
+        self.get_micro_batch_wait_s = get_micro_batch_wait_s
+        self.condition = threading.Condition()
+        self.pending_jobs: Deque[SchedulerPendingJob] = deque()
+        self.active_batch: T2SActiveBatch | None = None
+        self.state_lock = threading.Lock()
+        self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter())
+
+    @staticmethod
+    def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
+        if active_batch is None:
+            return {}
+        decode_step_index_max = 0
+        if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0:
+            decode_step_index_max = int(active_batch.step_indices.max().item())
+        return {
+            "request_count": int(len(active_batch.request_ids)),
+            "request_ids": list(active_batch.request_ids),
+            "prefill_done": bool(active_batch.prefill_done),
+            "decode_step_index_max": int(decode_step_index_max),
+        }
+
+    def snapshot_pending_queue_state(self) -> Dict[str, Any]:
+        with self.condition:
+            return {
+                "pending_jobs": int(len(self.pending_jobs)),
+                "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]],
+            }
+
+    def enqueue_pending_job(self, job: SchedulerPendingJob) -> None:
+        with self.condition:
+            self.pending_jobs.append(job)
+            self.condition.notify_all()
+        self.refresh_state("engine_decode_pending_enqueue")
+
+    def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
+        with self.condition:
+            if not self.pending_jobs:
+                return []
+            if wait_for_batch:
+                oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
+                if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s():
+                    return []
+            pending_jobs = list(self.pending_jobs)
+            self.pending_jobs.clear()
+        self.refresh_state("engine_decode_pending_dequeue")
+        return pending_jobs
+
+    def pending_age_ms(self) -> float:
+        with self.condition:
+            if not self.pending_jobs:
+                return 0.0
+            enqueue_time = float(self.pending_jobs[0].enqueue_time)
+            return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
+
+    def has_pending_jobs(self) -> bool:
+        with self.condition:
+            return bool(self.pending_jobs)
+
+    def get_active_batch(self) -> T2SActiveBatch | None:
+        return self.active_batch
+
+    def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None:
+        self.active_batch = active_batch
+
+    def active_batch_summary(self) -> Dict[str, Any]:
+        return self.summarize_active_batch(self.active_batch)
+
+    def refresh_state(self, last_event: str) -> None:
+        pending_state = self.snapshot_pending_queue_state()
+        active_batch_summary = self.active_batch_summary()
+        worker_decode_counters = self.get_decode_runtime_counters()
+        with self.state_lock:
+            self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
+            self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
+            self.state.active_request_count = int(active_batch_summary.get("request_count", 0))
+            self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32]
+            self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False))
+            self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0))
+            self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0))
+            self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0))
+            self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0))
+            self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0))
+            self.state.last_event = str(last_event)
+            self.state.updated_at = float(time.perf_counter())
+
+    def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None:
+        if not snapshot:
+            return
+        pending_state = self.snapshot_pending_queue_state()
+        with self.state_lock:
+            self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
+            self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
+            self.state.active_request_count = int(snapshot.get("active_request_count", 0))
+            self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32]
+            self.state.prefill_done = bool(snapshot.get("prefill_done", False))
+            self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0))
+            self.state.total_cycles = int(snapshot.get("total_cycles", 0))
+            self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0))
+            self.state.step_cycles = int(snapshot.get("step_cycles", 0))
+            self.state.has_work = bool(
+                pending_state.get("pending_jobs", 0)
+                or snapshot.get("active_request_count", 0)
+                or snapshot.get("has_work", False)
+            )
+            self.state.last_event = str(snapshot.get("last_event", "unknown"))
+            self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter()))
+
+    def snapshot_state(self) -> Dict[str, Any]:
+        pending_state = self.snapshot_pending_queue_state()
+        active_batch_summary = self.active_batch_summary()
+        worker_decode_counters = self.get_decode_runtime_counters()
+        with self.state_lock:
+            return {
+                "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)),
+                "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)),
+                "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)),
+                "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)),
+                "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)),
+                "decode_step_index_max": int(
+                    active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)
+                ),
+                "total_cycles": int(worker_decode_counters.get("total_cycles", 0)),
+                "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)),
+                "step_cycles": int(worker_decode_counters.get("step_cycles", 0)),
+                "has_work": bool(
+                    pending_state.get("pending_jobs", 0)
+                    or active_batch_summary.get("request_count", self.state.active_request_count)
+                    or self.state.has_work
+                ),
+                "last_event": str(self.state.last_event),
+                "updated_at": float(self.state.updated_at),
+            }
+
+
 @dataclass
 class SchedulerFinalizeTask:
     request_id: str
@@ -323,182 +1083,273 @@
     enqueued_time: float
 
 
+@dataclass
+class EngineDispatchTask:
+    request_id: str
+    state: T2SRequestState
+    speed_factor: float
+    sample_steps: int
+    media_type: str
+    prepare_wall_ms: float
+    prepare_profile_total_ms: float
+    done_loop: asyncio.AbstractEventLoop | None
+    done_future: asyncio.Future | None
+    engine_request_id: str | None
+    timeout_sec: float | None
+    enqueue_time: float
+    worker_job: SchedulerPendingJob | None = None
+    engine_policy_wait_ms: float = 0.0
+    engine_dispatch_wait_ms: float = 0.0
+    engine_policy_snapshot: Dict[str, Any] | None = None
+    error: str | None = None
+
+
+@dataclass
+class EngineGpuPrepareTask:
+    request_id: str
+    cpu_stage: PreparedCpuStage
+    done_loop: asyncio.AbstractEventLoop | None
+    done_future: asyncio.Future | None
+    engine_request_id: str | None
+    enqueue_time: float
+    queue_wait_ms: float = 0.0
+    error: str | None = None
+
+
+@dataclass
+class EngineFinalizeQueueState:
+    waiting_count: int
+    waiting_request_ids: List[str]
+    peak_waiting: int
+    total_submitted: int
+    total_completed: int
+
+
+@dataclass
+class EngineArbiterState:
+    total_ticks: int = 0
+    total_idle_ticks: int = 0
+    total_prepare_dispatches: int = 0
+    total_decode_dispatches: int = 0
+    total_decode_runtime_ticks: int = 0
+    total_finalize_dispatches: int = 0
+    decode_budget_remaining: int = 0
+    last_stage: str = "idle"
+    last_reason: str = "init"
+    last_observed_at: float = 0.0
+    last_policy_allowed: bool = True
+
+
+@dataclass
+class EngineDecodeRuntimeState:
+    pending_jobs: int = 0
+    pending_request_ids: List[str] = field(default_factory=list)
+    active_request_count: int = 0
+    active_request_ids: List[str] = field(default_factory=list)
+    prefill_done: bool = False
+    decode_step_index_max: int = 0
+    total_cycles: int = 0
+    prefill_cycles: int = 0
+    step_cycles: int = 0
+    has_work: bool = False
+    last_event: str = "init"
+    updated_at: float = 0.0
+
+
 @dataclass
 class RuntimeStateCallbacks:
     update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None
     complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None
     fail: Callable[[str, str], None] | None = None
+    decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None
 
 
-class UnifiedSchedulerWorker:
+class WorkerPrepareExecutor:
     def __init__(
         self,
         tts: TTS,
-        max_steps: int = 1500,
-        micro_batch_wait_ms: int = 5,
-        runtime_callbacks: RuntimeStateCallbacks | None = None,
-    ):
-        self.tts = tts
-        self.max_steps = int(max_steps)
-        self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0
-        self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
-        self.prepare_coordinator = PrepareCoordinator(tts)
-        self.condition = threading.Condition()
-        self.prepare_inflight = 0
-        self.prepare_peak_inflight = 0
-        self.finalize_condition = threading.Condition()
-        self.finalize_pending_tasks: Deque[SchedulerFinalizeTask] = deque()
-        self.finalize_pending_peak = 0
-        self.finalize_inflight = 0
-        self.finalize_inflight_peak = 0
-        self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
-        self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
-        self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
-        self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
-        self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0")))
-        self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0")))
-        self.pending_jobs: List[SchedulerPendingJob] = []
-        self.active_batch: T2SActiveBatch | None = None
-        self.job_map: Dict[str, SchedulerPendingJob] = {}
-        self.total_finished = 0
-        self.total_submitted = 0
-        self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True)
-        self.worker_thread.start()
-        self.finalize_threads = [
-            threading.Thread(
-                target=self._run_finalize_loop,
-                name=f"unified-t2s-finalize-{worker_index}",
-                daemon=True,
-            )
-            for worker_index in range(self.finalize_workers)
-        ]
-        for finalize_thread in self.finalize_threads:
-            finalize_thread.start()
+        on_state_change: Callable[[], None] | None = None,
+    ) -> None:
+        self.coordinator = PrepareCoordinator(tts)
+        self.on_state_change = on_state_change
 
-    def _current_decode_backlog_locked(self) -> int:
-        running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids)
-        return int(len(self.pending_jobs) + running_requests)
+    def _notify_state_change(self) -> None:
+        if self.on_state_change is None:
+            return
+        try:
+            self.on_state_change()
+        except Exception:
+            pass
 
-    def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]:
-        decode_backlog = self._current_decode_backlog_locked()
-        finalize_pending = int(len(self.finalize_pending_tasks))
-        prepare_inflight = int(self.prepare_coordinator.snapshot()["inflight"])
-        blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max
-        blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max
-        return (
-            not blocked_decode and not blocked_finalize,
-            {
-                "decode_backlog": decode_backlog,
-                "finalize_pending": finalize_pending,
-                "prepare_inflight": prepare_inflight,
-                "decode_backlog_max": int(self.decode_backlog_max),
-                "finalize_pending_max": int(self.finalize_pending_max),
-            },
-        )
+    def snapshot(self) -> Dict[str, int]:
+        return dict(self.coordinator.snapshot())
 
-    def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]:
-        start = time.perf_counter()
-        deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
-        last_snapshot: Dict[str, int] = {}
-        while True:
-            with self.condition:
-                allowed, snapshot = self._can_accept_submit_locked()
-                last_snapshot = snapshot
-                if allowed:
-                    return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot
-                if deadline is not None and time.perf_counter() >= deadline:
-                    raise TimeoutError(
-                        "scheduler submit admission timeout "
-                        f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})"
-                    )
-                self.condition.wait(timeout=self.micro_batch_wait_s)
+    def get_max_inflight(self) -> int:
+        return int(self.coordinator.snapshot().get("max_inflight", 0))
 
-    async def submit_async(
+    def is_idle(self) -> bool:
+        return int(self.coordinator.snapshot().get("inflight", 0)) <= 0
+
+    async def prepare_state_profiled_async(
         self,
-        state: T2SRequestState,
-        speed_factor: float,
-        sample_steps: int,
-        media_type: str,
-        prepare_wall_ms: float,
-        prepare_profile_total_ms: float,
-        done_loop: asyncio.AbstractEventLoop | None = None,
-        done_future: asyncio.Future | None = None,
-        engine_request_id: str | None = None,
-        timeout_sec: float | None = None,
-    ) -> SchedulerPendingJob:
-        return await asyncio.to_thread(
-            self.submit,
-            state,
-            speed_factor,
-            sample_steps,
-            media_type,
-            prepare_wall_ms,
-            prepare_profile_total_ms,
-            done_loop,
-            done_future,
-            engine_request_id,
-            timeout_sec,
-        )
+        spec: SchedulerRequestSpec,
+        prepare_submit_at: float,
+    ) -> tuple[T2SRequestState, float, float]:
+        try:
+            return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
+        finally:
+            self._notify_state_change()
 
-    def snapshot(self) -> dict:
+    async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
+        results = await asyncio.gather(
+            *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs]
+        )
+        return [state for state, _, _ in results]
+
+    async def prepare_cpu_stage_profiled_async(
+        self,
+        spec: SchedulerRequestSpec,
+        prepare_submit_at: float,
+    ) -> PreparedCpuStage:
+        try:
+            return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
+        finally:
+            self._notify_state_change()
+
+    async def prepare_gpu_stage_profiled_async(
+        self,
+        cpu_stage: PreparedCpuStage,
+    ) -> tuple[T2SRequestState, float, float]:
+        try:
+            return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage)
+        finally:
+            self._notify_state_change()
+
+
+class WorkerFinalizeExecutor:
+    def __init__(
+        self,
+        tts: TTS,
+        on_state_change: Callable[[], None] | None = None,
+        external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None,
+    ) -> None:
+        self.tts = tts
+        self.on_state_change = on_state_change
+        self.external_submit = external_submit
+        self.condition = threading.Condition()
+        self.pending_tasks: Deque[SchedulerFinalizeTask] = deque()
+        self.pending_peak = 0
+        self.inflight = 0
+        self.inflight_peak = 0
+        self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
+        self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
+        self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
+        self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
+
+    def _notify_state_change(self) -> None:
+        if self.on_state_change is None:
+            return
+        try:
+            self.on_state_change()
+        except Exception:
+            pass
+
+    def get_worker_count(self) -> int:
+        return int(self.worker_count)
+
+    def get_batch_policy(self) -> Dict[str, Any]:
+        return {
+            "finalize_mode": str(self.finalize_mode),
+            "finalize_batch_max_items": int(self.batch_max_items),
+            "finalize_batch_wait_s": float(self.batch_wait_s),
+        }
+
+    def get_pending_count(self) -> int:
+        with self.condition:
+            return int(len(self.pending_tasks))
+
+    def snapshot(self) -> Dict[str, Any]:
         with self.condition:
-            finalize_pending = len(self.finalize_pending_tasks)
-            prepare_state = self.prepare_coordinator.snapshot()
-            active_batch = self.active_batch
-            active_batch_summary = None
-            if active_batch is not None:
-                active_batch_summary = {
-                    "request_count": int(len(active_batch.request_ids)),
-                    "request_ids": list(active_batch.request_ids),
-                    "prefill_done": bool(active_batch.prefill_done),
-                    "decode_step_index_max": (
-                        int(active_batch.step_indices.max().item())
-                        if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0
-                        else 0
-                    ),
-                }
             return {
-                "pending_jobs": len(self.pending_jobs),
-                "running_requests": 0 if active_batch is None else len(active_batch.request_ids),
-                "prepare_inflight": prepare_state["inflight"],
-                "prepare_peak_inflight": prepare_state["peak_inflight"],
-                "prepare_max_inflight": prepare_state.get("max_inflight", 0),
-                "prepare_state": dict(prepare_state),
-                "finalize_pending": finalize_pending,
-                "finalize_pending_peak": self.finalize_pending_peak,
-                "finalize_inflight": self.finalize_inflight,
-                "finalize_inflight_peak": self.finalize_inflight_peak,
-                "finalize_workers": self.finalize_workers,
-                "finalize_mode": self.finalize_mode,
-                "finalize_batch_max_items": self.finalize_batch_max_items,
-                "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0,
-                "decode_backlog_max": self.decode_backlog_max,
-                "finalize_pending_max": self.finalize_pending_max,
-                "active_batch": active_batch_summary,
-                "total_submitted": self.total_submitted,
-                "total_finished": self.total_finished,
-                "drained": self.is_drained(),
+                "finalize_pending": int(len(self.pending_tasks)),
+                "finalize_pending_peak": int(self.pending_peak),
+                "finalize_inflight": int(self.inflight),
+                "finalize_inflight_peak": int(self.inflight_peak),
+                "finalize_workers": int(self.worker_count),
+                "finalize_mode": str(self.finalize_mode),
+                "finalize_batch_max_items": int(self.batch_max_items),
+                "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0),
             }
 
-    def is_drained(self) -> bool:
+    def is_idle(self) -> bool:
         with self.condition:
-            with self.finalize_condition:
-                return (
-                    self.active_batch is None
-                    and not self.pending_jobs
-                    and not self.job_map
-                    and self.prepare_coordinator.snapshot()["inflight"] <= 0
-                    and self.finalize_inflight <= 0
-                    and not self.finalize_pending_tasks
-                )
+            return self.inflight <= 0 and not self.pending_tasks
 
-    def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool:
-        deadline = time.perf_counter() + max(0.0, timeout_sec)
-        while time.perf_counter() < deadline:
-            if self.is_drained():
-                return True
-            time.sleep(poll_interval_sec)
-        return self.is_drained()
+    def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
+        if not tasks:
+            return
+        if self.external_submit is not None:
+            self.external_submit(tasks)
+            self._notify_state_change()
+            return
+        with self.condition:
+            for task in tasks:
+                self.pending_tasks.append(task)
+            self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
+            self.condition.notify_all()
+        self._notify_state_change()
+
+    def begin_execution(self, task_count: int) -> None:
+        if task_count <= 0:
+            return
+        with self.condition:
+            self.inflight += int(task_count)
+            self.inflight_peak = max(self.inflight_peak, self.inflight)
+            self.condition.notify_all()
+        self._notify_state_change()
+
+    def end_execution(self, task_count: int) -> None:
+        with self.condition:
+            self.inflight = max(0, self.inflight - int(task_count))
+            self.condition.notify_all()
+        self._notify_state_change()
+
+    def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]:
+        with self.condition:
+            while not self.pending_tasks:
+                self.condition.wait()
+            selected_tasks = [self.pending_tasks.popleft()]
+            if self.finalize_mode == "sync" or self.tts.configs.use_vocoder:
+                self.inflight += len(selected_tasks)
+                self.inflight_peak = max(self.inflight_peak, self.inflight)
+                self._notify_state_change()
+                return selected_tasks
+            batch_deadline = time.perf_counter() + self.batch_wait_s
+            while len(selected_tasks) < self.batch_max_items:
+                if not self.pending_tasks:
+                    remaining = batch_deadline - time.perf_counter()
+                    if remaining <= 0:
+                        break
+                    self.condition.wait(timeout=remaining)
+                    continue
+                first_task = selected_tasks[0]
+                matched_index = None
+                for index, task in enumerate(self.pending_tasks):
+                    if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
+                        matched_index = index
+                        break
+                if matched_index is not None:
+                    selected_tasks.append(self.pending_tasks[matched_index])
+                    del self.pending_tasks[matched_index]
+                    continue
+                remaining = batch_deadline - time.perf_counter()
+                if remaining <= 0:
+                    break
+                self.condition.wait(timeout=remaining)
+            self.inflight += len(selected_tasks)
+            self.inflight_peak = max(self.inflight_peak, self.inflight)
+            self._notify_state_change()
+            return selected_tasks
 
     def _sync_device(self) -> None:
         try:
@@ -510,201 +1361,6 @@ class UnifiedSchedulerWorker:
         except Exception:
             pass
 
-    def submit(
-        self,
-        state: T2SRequestState,
-        speed_factor: float,
-        sample_steps: int,
-        media_type: str,
-        prepare_wall_ms: float,
-        prepare_profile_total_ms: float,
-        done_loop: asyncio.AbstractEventLoop | None = None,
-        done_future: asyncio.Future | None = None,
-        engine_request_id: str | None = None,
-        timeout_sec: float | None = None,
-    ) -> SchedulerPendingJob:
-        admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec)
-        job = SchedulerPendingJob(
-            request_id=state.request_id,
-            state=state,
-            done_event=threading.Event(),
-            done_loop=done_loop,
-            done_future=done_future,
-            enqueue_time=time.perf_counter(),
-            speed_factor=float(speed_factor),
-            sample_steps=int(sample_steps),
-            media_type=media_type,
-            admission_wait_ms=float(admission_wait_ms),
-            prepare_wall_ms=float(prepare_wall_ms),
-            prepare_profile_total_ms=float(prepare_profile_total_ms),
-            engine_request_id=engine_request_id or state.request_id,
-        )
-        with self.condition:
-            self.pending_jobs.append(job)
-            self.job_map[job.request_id] = job
-            self.total_submitted += 1
-            self.condition.notify_all()
-        self._runtime_update(
-            job.engine_request_id,
-            EngineStatus.QUEUED,
-            {
-                "scheduler_request_id": job.request_id,
-                "decode_admission_wait_ms": float(admission_wait_ms),
-                "admission_snapshot": dict(admission_snapshot),
-            },
-        )
-        with self.finalize_condition:
-            self.finalize_condition.notify_all()
-        return job
-
-    async def prepare_state_profiled_async(
-        self,
-        spec: SchedulerRequestSpec,
-        prepare_submit_at: float,
-    ) -> tuple[T2SRequestState, float, float]:
-        with self.condition:
-            self.prepare_inflight += 1
-            self.prepare_peak_inflight = max(self.prepare_peak_inflight, self.prepare_inflight)
-        try:
-            return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
-        finally:
-            with self.condition:
-                self.prepare_inflight = max(0, self.prepare_inflight - 1)
-                self.condition.notify_all()
-            with self.finalize_condition:
-                self.finalize_condition.notify_all()
-
-    async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
-        results = await asyncio.gather(
-            *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs]
-        )
-        return [state for state, _, _ in results]
-
-    def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None:
-        with self.condition:
-            for job in pending_jobs:
-                tracked_job = self.job_map.get(job.request_id)
-                if tracked_job is None:
-                    continue
-                tracked_job.first_schedule_time = float(started_at)
-                self._runtime_update(
-                    tracked_job.engine_request_id,
-                    EngineStatus.GPU_PREPARING,
-                    {"scheduler_request_id": tracked_job.request_id, "prefill_started_at": float(started_at)},
-                )
-
-    def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None:
-        delta_ms = float(elapsed_s) * 1000.0
-        if not request_ids:
-            return
-        with self.condition:
-            for request_id in request_ids:
-                job = self.job_map.get(request_id)
-                if job is not None:
-                    job.prefill_ms += delta_ms
-
-    def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
-        delta_ms = float(elapsed_s) * 1000.0
-        if not request_ids:
-            return
-        with self.condition:
-            for request_id in request_ids:
-                job = self.job_map.get(request_id)
-                if job is not None:
-                    job.merge_ms += delta_ms
-
-    def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
-        delta_ms = float(elapsed_s) * 1000.0
-        if not request_ids:
-            return
-        activate_request_ids: List[str] = []
-        with self.condition:
-            for request_id in request_ids:
-                job = self.job_map.get(request_id)
-                if job is not None:
-                    if job.decode_steps == 0:
-                        activate_request_ids.append(job.engine_request_id)
-                    job.decode_ms += delta_ms
-                    job.decode_steps += 1
-        for engine_request_id in activate_request_ids:
-            self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
-
-    def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None:
-        if not request_ids:
-            return
-        with self.condition:
-            for request_id in request_ids:
-                job = self.job_map.get(request_id)
-                if job is not None:
-                    job.finalize_wait_ms += float(delta_ms)
-
-    def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None:
-        if not items:
-            return
-        enqueued_at = time.perf_counter()
-        with self.finalize_condition:
-            for item in items:
-                job = self.job_map.get(item.request_id)
-                if job is not None:
-                    self._runtime_update(
-                        job.engine_request_id,
-                        EngineStatus.READY_FOR_FINALIZE,
-                        {
-                            "finish_reason": item.finish_reason,
-                            "semantic_len": int(item.semantic_tokens.shape[0]),
-                            "finish_idx": int(item.finish_idx),
-                        },
-                    )
-                self.finalize_pending_tasks.append(
-                    SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)
-                )
-            self.finalize_pending_peak = max(self.finalize_pending_peak, len(self.finalize_pending_tasks))
-            self.finalize_condition.notify_all()
-
-    def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]:
-        with self.finalize_condition:
-            while not self.finalize_pending_tasks:
-                self.finalize_condition.wait()
-            selected_tasks = [self.finalize_pending_tasks.popleft()]
-            if self.finalize_mode == "sync" or self.tts.configs.use_vocoder:
-                self.finalize_inflight += len(selected_tasks)
-                self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight)
-                return selected_tasks
-            batch_deadline = time.perf_counter() + self.finalize_batch_wait_s
-            while len(selected_tasks) < self.finalize_batch_max_items:
-                if not self.finalize_pending_tasks:
-                    remaining = batch_deadline - time.perf_counter()
-                    if remaining <= 0:
-                        break
-                    self.finalize_condition.wait(timeout=remaining)
-                    continue
-                first_task = selected_tasks[0]
-                matched_index = None
-                for index, task in enumerate(self.finalize_pending_tasks):
-                    if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
-                        matched_index = index
-                        break
-                if matched_index is not None:
-                    selected_tasks.append(self.finalize_pending_tasks[matched_index])
-                    del self.finalize_pending_tasks[matched_index]
-                    continue
-                remaining = batch_deadline - time.perf_counter()
-                if remaining <= 0:
-                    break
-                self.finalize_condition.wait(timeout=remaining)
-            self.finalize_inflight += len(selected_tasks)
-            self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight)
-        with self.condition:
-            self.condition.notify_all()
-        return selected_tasks
-
-    def _finalize_task_done(self, count: int) -> None:
-        with self.finalize_condition:
-            self.finalize_inflight = max(0, self.finalize_inflight - count)
-            self.finalize_condition.notify_all()
-        with self.condition:
-            self.condition.notify_all()
-
     def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
         audio_fragment = self.tts.synthesize_audio_request_local(
             semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0),
@@ -772,93 +1428,36 @@
         )
         return results
 
-    def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None:
-        finished_at = time.perf_counter()
-        with self.condition:
-            if self.job_map.get(item.request_id) is not job:
-                return
-            queue_wait_ms = 0.0
-            if job.first_schedule_time is not None:
-                queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
-            worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
-            worker_residual_ms = max(
-                0.0,
-                worker_total_ms
-                - queue_wait_ms
-                - job.prefill_ms
-                - job.merge_ms
-                - job.decode_ms
-                - job.finalize_wait_ms
-                - job.synth_ms,
-            )
-            worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
-            job.sample_rate = int(sample_rate)
-            job.audio_data = audio_data
-            job.result_ready_time = finished_at
-            prepare_profile = dict(job.state.prepare_profile)
-            job.result = {
-                "request_id": item.request_id,
-                "semantic_len": int(item.semantic_tokens.shape[0]),
-                "finish_idx": int(item.finish_idx),
-                "finish_reason": item.finish_reason,
-                "decode_admission_wait_ms": float(job.admission_wait_ms),
-                "prepare_ms": job.prepare_wall_ms,
-                "prepare_wall_ms": job.prepare_wall_ms,
-                "prepare_profile_total_ms": job.prepare_profile_total_ms,
-                "prepare_profile": prepare_profile,
-                "queue_wait_ms": queue_wait_ms,
-                "prefill_ms": job.prefill_ms,
-                "merge_ms": job.merge_ms,
-                "decode_ms": job.decode_ms,
-                "finalize_wait_ms": job.finalize_wait_ms,
-                "synth_ms": job.synth_ms,
-                "worker_residual_ms": worker_residual_ms,
-                "worker_other_ms": worker_other_ms,
-                "worker_total_ms": worker_total_ms,
-                "decode_steps": int(job.decode_steps),
-                "sample_rate": int(sample_rate),
-                "media_type": job.media_type,
-            }
-            job.done_event.set()
-            self._notify_done_future(job)
-            self.job_map.pop(item.request_id, None)
-            self.total_finished += 1
-            self.condition.notify_all()
-        self._runtime_complete(
-            job.engine_request_id,
-            {
-                "finish_reason": item.finish_reason,
-                "semantic_len": int(item.semantic_tokens.shape[0]),
-                "finish_idx": int(item.finish_idx),
-                "sample_rate": int(sample_rate),
-                "worker_profile": dict(job.result or {}),
-            },
-        )
+    def synthesize_finalize_jobs(
+        self,
+        jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
+    ) -> tuple[float, List[tuple[int, np.ndarray]]]:
+        if not jobs_and_items:
+            return 0.0, []
+        self._sync_device()
+        synth_start = time.perf_counter()
+        if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder:
+            job, item = jobs_and_items[0]
+            batch_results = [self._synthesize_finished_audio(job, item)]
+        else:
+            batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
+        self._sync_device()
+        synth_ms = (time.perf_counter() - synth_start) * 1000.0
+        return float(synth_ms), batch_results
 
-    def _finalize_error(self, request_ids: List[str], error: str) -> None:
-        if not request_ids:
-            return
-        with self.condition:
-            for request_id in request_ids:
-                job = self.job_map.get(request_id)
-                if job is None:
-                    continue
-                job.error = error
-                job.done_event.set()
-                self._notify_done_future(job)
-                self.job_map.pop(request_id, None)
-                self.total_finished += 1
-                self._runtime_fail(job.engine_request_id, error)
-            self.condition.notify_all()
+
+class WorkerCompletionBridge:
+    def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None:
+        self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
 
     @staticmethod
     def _resolve_done_future(job: SchedulerPendingJob) -> None:
        future = job.done_future
        if future is None or future.done():
            return
-        future.set_result(True)
+        future.set_result(job)
 
-    def _notify_done_future(self, job: SchedulerPendingJob) -> None:
+    def notify_done_future(self, job: SchedulerPendingJob) -> None:
         if job.done_loop is None or job.done_future is None:
             return
         try:
@@ -866,22 +1465,409 @@
         except RuntimeError:
             pass
 
-    def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
-        if request_id is None or self.runtime_callbacks.update is None:
-            return
-        self.runtime_callbacks.update(request_id, status, extra)
-
-    def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
+    def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
         if request_id is None or self.runtime_callbacks.complete is None:
             return
         self.runtime_callbacks.complete(request_id, extra)
 
-    def _runtime_fail(self, request_id: str | None, error: str) -> None:
+    def runtime_fail(self, request_id: str | None, error: str) -> None:
         if request_id is None or self.runtime_callbacks.fail is None:
             return
         self.runtime_callbacks.fail(request_id, error)
 
-    def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
+    @staticmethod
+    def build_completed_job_result(
+        job: SchedulerPendingJob,
+        item: T2SFinishedItem,
+        *,
+        sample_rate: int,
+        audio_data: np.ndarray,
+        finished_at: float | None = None,
+    ) -> Dict[str, Any]:
+        finished_at = float(time.perf_counter() if finished_at is None else finished_at)
+        queue_wait_ms = 0.0
+        if job.first_schedule_time is not None:
+            queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
+        worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
+        worker_residual_ms = max(
+            0.0,
+            worker_total_ms
+            - queue_wait_ms
+            - job.prefill_ms
+            - job.merge_ms
+            - job.decode_ms
+            - job.finalize_wait_ms
+            - job.synth_ms,
+        )
+        worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
+        job.sample_rate = int(sample_rate)
+        job.audio_data = audio_data
+        job.result_ready_time = finished_at
+        prepare_profile = dict(job.state.prepare_profile)
+        result = {
+            "request_id": item.request_id,
+            "semantic_len": int(item.semantic_tokens.shape[0]),
+            "finish_idx": int(item.finish_idx),
+            "finish_reason": item.finish_reason,
+            "decode_admission_wait_ms": float(job.admission_wait_ms),
+            "engine_policy_wait_ms": float(job.engine_policy_wait_ms),
+            "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms),
+            "prepare_ms": job.prepare_wall_ms,
+            "prepare_wall_ms": job.prepare_wall_ms,
+            "prepare_profile_total_ms": job.prepare_profile_total_ms,
+            "prepare_profile": prepare_profile,
+            "queue_wait_ms": queue_wait_ms,
+            "prefill_ms": job.prefill_ms,
+            "merge_ms": job.merge_ms,
+            "decode_ms": job.decode_ms,
+            "finalize_wait_ms": job.finalize_wait_ms,
+            "synth_ms": job.synth_ms,
+            "worker_residual_ms": worker_residual_ms,
+            "worker_other_ms": worker_other_ms,
+            "worker_total_ms": worker_total_ms,
+            "decode_steps": int(job.decode_steps),
+            "sample_rate": int(sample_rate),
+            "media_type": job.media_type,
+        }
+        job.result = result
+        return result
+
+    @staticmethod
+    def build_runtime_complete_payload(
+        job: SchedulerPendingJob,
+        item: T2SFinishedItem,
+        *,
+        sample_rate: int,
+    ) -> Dict[str, Any]:
+        return {
+            "finish_reason": item.finish_reason,
+            "semantic_len": int(item.semantic_tokens.shape[0]),
+            "finish_idx": int(item.finish_idx),
+            "sample_rate": int(sample_rate),
+            "worker_profile": dict(job.result or {}),
+        }
+
+    def complete_job(
+        self,
+        job: SchedulerPendingJob,
+        *,
+        runtime_request_id: str | None,
+        runtime_extra: Optional[Dict[str, Any]] = None,
+        remove_job:
Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() 
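+ # Synchronize the device before sampling the clock so kernels still in
+ # flight from earlier work are not billed to this prefill phase; the
+ # matching _sync_device() after run_prefill_active_batch closes the
+ # window, making prefill_elapsed_s a wall-clock bound on this batch only.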
+ prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + 
"prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: with self.condition: if not self.pending_jobs and self.active_batch is None: self.condition.wait(timeout=self.micro_batch_wait_s) @@ -893,104 +1879,781 @@ class UnifiedSchedulerWorker: self.pending_jobs.clear() return pending - def _run_finalize_loop(self) -> None: + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + 
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + 
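+ # The shell owns active_batch only when no external state was passed in,
+ # so the cycle result is written back under the lock and waiters are
+ # notified below.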
self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: while True: - tasks = self._take_finalize_task_batch() - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: with self.condition: - for task in tasks: - job = self.job_map.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - continue - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) + + +class UnifiedSchedulerWorker: + def __init__( + self, + tts: TTS, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + runtime_callbacks: RuntimeStateCallbacks | None = None, + external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.condition = threading.Condition() + self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) + self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) + self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) + self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) + self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) + self.finalize_executor = WorkerFinalizeExecutor( + tts, + 
on_state_change=self._notify_worker_state_change, + external_submit=external_finalize_submit, + ) + self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) + self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) + self.engine_decode_control_enabled = ( + str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} + ) + self.job_registry = SchedulerJobRegistry(self.condition) + self.worker_thread: threading.Thread | None = None + if not self.engine_decode_control_enabled: + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [] + if external_finalize_submit is None: + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_executor.get_worker_count()) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def _notify_worker_state_change(self) -> None: + with self.condition: + self.condition.notify_all() + + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + last_snapshot: Dict[str, int] = {} + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + last_snapshot = snapshot + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + 
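+ # Admission is a poll loop: each pass re-reads the backlog/finalize
+ # gauges under the lock, and the timed wait below is cut short whenever
+ # another thread notifies the condition; a caller would typically do
+ # e.g. wait_ms, snapshot = worker.wait_for_submit_capacity_blocking(timeout_sec=5.0)
+ # (illustrative only; "worker" stands for this UnifiedSchedulerWorker).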
self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + 
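+ # Drained means no work anywhere: the decode shell is idle, the job
+ # registry is empty, and both prepare and finalize executors report idle.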
return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, 
float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: self._runtime_update( job.engine_request_id, - EngineStatus.FINALIZING, + EngineStatus.READY_FOR_FINALIZE, { "finish_reason": item.finish_reason, "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), }, ) - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_map.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self._finalize_task_done(len(tasks)) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + 
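+ # Thin pass-throughs so an external finalize driver can keep the
+ # executor's inflight accounting balanced: begin before running a batch,
+ # end in a finally block, mirroring execute_finalize_tasks below.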
self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() + + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if 
external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + 
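+ # Finalize pipeline for one task batch: look up still-tracked jobs under
+ # the lock, record each item's queue wait, mark the jobs FINALIZING,
+ # synthesize audio (batched when there is more than one item and the
+ # vocoder is not in use), then complete or fail them; end_execution in
+ # the finally block releases the inflight slots either way.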
try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) def _run_loop(self) -> None: - while True: - wait_for_batch = self.active_batch is None - pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) - - if pending_jobs: - try: - self._sync_device() - prefill_start = time.perf_counter() - self._mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - self._add_prefill_time([job.request_id for job in pending_jobs], time.perf_counter() - prefill_start) - self._enqueue_finalize_finished(admitted_finished) - merge_start = time.perf_counter() - self.active_batch = merge_active_batches( - self.tts.t2s_model.model, - self.active_batch, - admitted_active_batch, - ) - self._add_merge_time( - [] if self.active_batch is None else list(self.active_batch.request_ids), - time.perf_counter() - merge_start, - ) - except Exception as exc: - self._finalize_error([job.request_id for job in pending_jobs], str(exc)) - - if self.active_batch is not None: - active_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in self.active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - self.active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - self.active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) - self._enqueue_finalize_finished(step_finished) - except Exception as exc: - self._finalize_error(active_request_ids, str(exc)) - self.active_batch = None - continue - - if not pending_jobs: - time.sleep(self.micro_batch_wait_s) + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) def set_scheduler_seed(seed: int): @@ -1088,6 +2751,27 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=3 class UnifiedTTSEngine: + @staticmethod + def _env_flag(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return bool(default) + return 
str(value).strip().lower() not in {"0", "false", "no", "off", ""} + + @staticmethod + def _env_int(name: str, default: int) -> int: + value = os.environ.get(name) + if value in [None, ""]: + return int(default) + return int(value) + + @staticmethod + def _env_float(name: str, default: float) -> float: + value = os.environ.get(name) + if value in [None, ""]: + return float(default) + return float(value) + def __init__( self, tts: TTS, @@ -1104,10 +2788,10 @@ class UnifiedTTSEngine: t2s_weights_path=str(self.tts.configs.t2s_weights_path), vits_weights_path=str(self.tts.configs.vits_weights_path), ) - self.request_registry_lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_request_limit = max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + self.request_registry = EngineRequestRegistry( + recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + ) + self.engine_job_registry = SchedulerJobRegistry(threading.Lock()) self.scheduler_worker = UnifiedSchedulerWorker( tts, max_steps=max_steps, @@ -1116,10 +2800,72 @@ class UnifiedTTSEngine: update=self._update_request_state, complete=self._complete_request_state, fail=self._fail_request_state, + decode_runtime_update=self._update_engine_decode_runtime_state, ), + external_finalize_submit=self._enqueue_worker_finished_for_finalize, ) self.direct_tts_lock = threading.RLock() self.management_lock = threading.RLock() + worker_capacity_limits = self.scheduler_worker.get_capacity_limits() + prepare_max_inflight = int(self.scheduler_worker.get_prepare_max_inflight()) + self.engine_policy_config = EnginePolicyConfig( + enabled=self._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), + poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_backlog_soft_max=max( + 0, + self._env_int( + "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", + int(worker_capacity_limits["decode_backlog_max"]), + ), + ), + finalize_pending_soft_max=max( + 0, + self._env_int( + "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", + int(worker_capacity_limits["finalize_pending_max"]), + ), + ), + prepare_inflight_soft_max=max( + 0, + self._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), + ), + active_decode_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), + ready_for_prefill_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), + active_request_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), + ) + self.engine_arbiter_config = EngineArbiterConfig( + poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_burst=max(1, self._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), + prepare_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), + finalize_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), + ) + self.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( + get_decode_runtime_counters=self.scheduler_worker.get_decode_runtime_counters, + get_micro_batch_wait_s=self.scheduler_worker.get_micro_batch_wait_s, + ) + self.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + self.engine_finalize_queue_owner = 
EngineTaskQueueOwner(completion_key="total_completed") + self.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") + self.engine_dispatch_last_snapshot: Dict[str, Any] = {} + self.engine_policy_arbiter = EnginePolicyArbiterController( + policy_config=self.engine_policy_config, + arbiter_config=self.engine_arbiter_config, + snapshot_request_registry=self._snapshot_request_registry, + get_worker_state=self.get_scheduler_state, + snapshot_prepare_state=self._snapshot_engine_prepare_state, + snapshot_finalize_state=self._snapshot_engine_finalize_state, + snapshot_dispatch_state=self._snapshot_engine_dispatch_state, + snapshot_decode_runtime_state=self._snapshot_engine_decode_runtime_state, + snapshot_job_registry=self._snapshot_engine_job_registry, + peek_queue_age_ms=self._peek_queue_age_ms, + merge_request_state_profile=self._merge_request_state_profile, + ) + self.engine_arbiter_thread = threading.Thread( + target=self._run_engine_arbiter_loop, + name="unified-engine-arbiter", + daemon=True, + ) + self.engine_arbiter_thread.start() def _register_request_state( self, @@ -1131,27 +2877,15 @@ class UnifiedTTSEngine: deadline_ts: float | None = None, meta: Optional[Dict[str, Any]] = None, ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( + return self.request_registry.register( request_id=request_id, api_mode=api_mode, backend=backend, media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, + response_streaming=response_streaming, deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, + meta=meta, ) - with self.request_registry_lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_request_limit: - self.recent_requests.pop() def _update_request_state( self, @@ -1159,95 +2893,501 @@ class UnifiedTTSEngine: status: str, extra: Optional[Dict[str, Any]] = None, ) -> None: - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = status - state.updated_ts = now - state.lifecycle_timestamps[status] = now - if extra: - backend = extra.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = extra.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = extra.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(extra) + self.request_registry.update(request_id, status, extra) def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: + self.request_registry.merge_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, 
keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: return - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - backend = extra.get("backend") - if backend is not None: - state.backend = str(backend) - finish_reason = extra.get("finish_reason") - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = extra.get("error") - if error is not None: - state.error = str(error) - merged = dict(extra) - merged.pop("backend", None) - merged.pop("finish_reason", None) - merged.pop("error", None) - state.profile.update(merged) + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def 
_add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self._enqueue_worker_finished_for_finalize(tasks) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_decode_runtime_owner.enqueue_pending_job(job) + self._notify_engine_arbiter() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.engine_prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + elif queue_name == "finalize": + return self.engine_finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + elif queue_name == "decode_runtime_pending": + return self.engine_decode_runtime_owner.pending_age_ms() + else: + return self.engine_dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def _engine_has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.engine_decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self._snapshot_engine_decode_runtime_state().get("active_request_count", 0) > 0: + return True + if self.engine_prepare_queue_owner.has_items(): + return True + if self.engine_finalize_queue_owner.has_items(): + return True + return self.engine_dispatch_queue_owner.has_items() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or 
task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self._update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.engine_prepare_queue_owner.enqueue(task) + self._notify_engine_arbiter() + state, prepare_exec_started_at, prepare_exec_finished_at = await done_future + return state, prepare_exec_started_at, prepare_exec_finished_at + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self._get_engine_job(task.request_id) + if job is not None: + self._update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.engine_finalize_queue_owner.enqueue_many(tasks) + self._notify_engine_arbiter() + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.engine_finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + 
done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.engine_dispatch_queue_owner.enqueue(task) + self._notify_engine_arbiter() + self._merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int(self._snapshot_engine_dispatch_state()["waiting_count"]), + }, + ) + return task + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state + + def _run_engine_prepare_once(self) -> bool: + task = self.engine_prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self._merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.engine_prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self._fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def _run_engine_finalize_once(self) -> bool: + tasks = self._take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self._get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self._get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self._update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as 
exc: + self._fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.engine_finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.engine_dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self._register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self._enqueue_engine_decode_pending_job(worker_job) + self.engine_dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self._fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def _run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self._snapshot_engine_decode_runtime_state() + pending_jobs = self._take_engine_decode_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.engine_decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self._fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self._add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self._add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self._enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self._fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self._add_engine_decode_time( + 
list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self._enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.engine_decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self._refresh_engine_decode_runtime_state("engine_decode_cycle") + return bool(result.get("executed", False)) + + def _run_engine_arbiter_loop(self) -> None: + while True: + if not self._engine_has_pending_work(): + self._mark_arbiter_tick(stage="idle", reason="no_pending_work", policy_allowed=True) + self.engine_policy_arbiter.wait() + continue + stage, reason, policy_snapshot, worker_state = self._select_engine_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self._run_engine_prepare_once() + elif stage == "finalize": + executed = self._run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self._run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self._run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick(stage="idle", reason=f"{stage}_not_ready", policy_allowed=policy_allowed) + self.engine_policy_arbiter.wait() + continue + self._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - if extra: - finish_reason = extra.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - state.profile.update(extra) - self._move_to_recent_locked(state) + self.request_registry.complete(request_id, extra) def _fail_request_state(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) + self.request_registry.fail(request_id, error) def _snapshot_request_registry(self) -> Dict[str, Any]: - with self.request_registry_lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": self.recent_request_limit, - "active_requests": active, - "recent_requests": recent, - } + return self.request_registry.snapshot() @staticmethod def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: @@ -1258,41 +3398,56 @@ class UnifiedTTSEngine: except Exception: return None + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return 
self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + def _build_stage_summary( self, request_registry: Dict[str, Any], worker_state: Dict[str, Any], ) -> Dict[str, Any]: - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - + counters = self._build_stage_counters(request_registry, worker_state) bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": int(worker_state.get("pending_jobs", 0)), - "worker_decode_active_size": int(worker_state.get("running_requests", 0)), - "worker_prepare_inflight": int(worker_state.get("prepare_inflight", 0)), - "worker_finalize_pending": int(worker_state.get("finalize_pending", 0)), - "worker_finalize_inflight": int(worker_state.get("finalize_inflight", 0)), + **counters, + "engine_drained": bool(self._is_engine_drained()), "admission_config": { "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), }, + "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), + "engine_arbiter_state": self._snapshot_engine_arbiter_state(), + "engine_decode_runtime_state": self._snapshot_engine_decode_runtime_state(), + "engine_job_registry": self._snapshot_engine_job_registry(), + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": self._snapshot_engine_prepare_state(), + "engine_finalize_state": self._snapshot_engine_finalize_state(), + "engine_dispatcher_state": self._snapshot_engine_dispatch_state(), "active_batch": dict(worker_state.get("active_batch") or {}), "prepare_state": dict(worker_state.get("prepare_state") or {}), "bert_batch_worker_state": bert_worker_state, @@ -1301,21 +3456,10 @@ class UnifiedTTSEngine: } def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.request_registry_lock: - for state in 
self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - for state in self.recent_requests: - if state.request_id in requested and all(item["request_id"] != state.request_id for item in results): - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results + return self.request_registry.collect_summaries(request_ids) def _has_active_request(self, request_id: str) -> bool: - with self.request_registry_lock: - return request_id in self.active_requests + return self.request_registry.has_active(request_id) @staticmethod def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: @@ -1356,6 +3500,11 @@ class UnifiedTTSEngine: "text_len": len(str(segment_text)), "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "prepare_engine_gpu_queue_wait_ms": float( + dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) + ), + "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), + "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), @@ -1391,6 +3540,8 @@ class UnifiedTTSEngine: request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") + engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") @@ -1403,7 +3554,7 @@ class UnifiedTTSEngine: semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) request_other_ms = max( 0.0, - request_total_ms - prepare_wall_ms - worker_total_ms - pack_ms - response_overhead_ms, + request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, ) return { "backend": backend, @@ -1415,6 +3566,8 @@ class UnifiedTTSEngine: "prepare_ms": prepare_wall_ms, "prepare_wall_ms": prepare_wall_ms, "prepare_profile_total_ms": prepare_profile_total_ms, + "engine_policy_wait_ms": engine_policy_wait_ms, + "engine_dispatch_wait_ms": engine_dispatch_wait_ms, "decode_admission_wait_ms": decode_admission_wait_ms, "queue_wait_ms": queue_wait_ms, "prefill_ms": prefill_ms, @@ -1488,6 +3641,7 @@ class UnifiedTTSEngine: prepare_profile_total_ms: float, prepare_profile_wall_ms: float, prepare_other_ms: float, + engine_policy_wait_ms: float, api_after_prepare_ms: float, api_wait_result_ms: float, pack_ms: float, @@ -1498,7 +3652,13 @@ class UnifiedTTSEngine: request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) request_other_ms = max( 0.0, - request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + request_total_ms + - prepare_wall_ms + - engine_policy_wait_ms + - api_after_prepare_ms + - worker_total_ms + - 
api_wait_result_ms + - pack_ms, ) result = { "backend": backend, @@ -1513,6 +3673,7 @@ class UnifiedTTSEngine: "prepare_profile_total_ms": prepare_profile_total_ms, "prepare_profile_wall_ms": prepare_profile_wall_ms, "prepare_other_ms": prepare_other_ms, + "engine_policy_wait_ms": float(engine_policy_wait_ms), "api_after_prepare_ms": api_after_prepare_ms, "api_wait_result_ms": api_wait_result_ms, "pack_ms": pack_ms, @@ -1542,6 +3703,8 @@ class UnifiedTTSEngine: "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), + "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), @@ -1598,6 +3761,7 @@ class UnifiedTTSEngine: "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), @@ -2000,14 +4164,22 @@ class UnifiedTTSEngine: prepared_items = await asyncio.gather( *[ - self.scheduler_worker.prepare_state_profiled_async(spec, time.perf_counter()) + self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) for spec in segment_specs ] ) prepare_profiles: List[Dict[str, Any]] = [] - jobs: List[SchedulerPendingJob] = [] loop = asyncio.get_running_loop() done_futures: List[asyncio.Future] = [] + self._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, + ) for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) @@ -2021,39 +4193,38 @@ class UnifiedTTSEngine: ) done_future = loop.create_future() done_futures.append(done_future) - jobs.append( - await self.scheduler_worker.submit_async( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=None, - timeout_sec=normalized.timeout_sec, - ) + await self._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + 
sample_steps=int(normalized.sample_steps), + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, ) - self._update_request_state( - request_id, - EngineStatus.READY_FOR_PREFILL, - { - "backend": "scheduler_v1_direct", - "backend_mode": "scheduler_v1_direct", - "segment_count": len(segment_specs), - "prepare_aggregate": self._aggregate_numeric_dicts( - [item["prepare_profile"] for item in prepare_profiles] - ), - }, - ) self._update_request_state( request_id, EngineStatus.ACTIVE_DECODE, {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, ) timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) - await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec) + jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) + for profile_item, job in zip(prepare_profiles, jobs): + profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) + profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) + self._merge_request_state_profile( + request_id, + { + "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), + "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), + "prepare_aggregate": self._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + }, + ) sample_rate: int | None = None audio_parts: List[np.ndarray] = [] @@ -2470,9 +4641,10 @@ class UnifiedTTSEngine: prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) try: - state, prepare_exec_started_at, prepare_exec_finished_at = await self.scheduler_worker.prepare_state_profiled_async( - spec, - spec_ready_at, + state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=spec_ready_at, + engine_request_id=spec.request_id, ) except Exception as exc: self._fail_request_state(spec.request_id, str(exc)) @@ -2496,7 +4668,7 @@ class UnifiedTTSEngine: api_after_prepare_start = time.perf_counter() loop = asyncio.get_running_loop() done_future = loop.create_future() - job = await self.scheduler_worker.submit_async( + await self._enqueue_prepared_state_for_dispatch( state=state, speed_factor=float(normalized.speed_factor), sample_steps=int(normalized.sample_steps), @@ -2510,7 +4682,7 @@ class UnifiedTTSEngine: ) api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) try: - await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) except Exception as exc: self._fail_request_state(spec.request_id, str(exc)) raise @@ -2542,6 +4714,7 @@ class UnifiedTTSEngine: prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_wall_ms=prepare_profile_wall_ms, prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), api_after_prepare_ms=api_after_prepare_ms, api_wait_result_ms=api_wait_result_ms, pack_ms=pack_ms, @@ -2568,6 +4741,14 @@ class 
UnifiedTTSEngine: default_ref = self.reference_registry.get_default() scheduler_state = self.get_scheduler_state() request_registry = self._snapshot_request_registry() + engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) + engine_arbiter_state = self._snapshot_engine_arbiter_state() + engine_decode_runtime_state = self._snapshot_engine_decode_runtime_state() + engine_job_registry = self._snapshot_engine_job_registry() + engine_prepare_state = self._snapshot_engine_prepare_state() + engine_finalize_state = self._snapshot_engine_finalize_state() + engine_dispatcher_state = self._snapshot_engine_dispatch_state() + engine_drained = self._is_engine_drained() return { "message": "success", "default_reference": { @@ -2583,6 +4764,15 @@ class UnifiedTTSEngine: "updated_at": model_state.updated_at, }, "worker_state": scheduler_state, + "engine_policy": engine_policy, + "engine_arbiter_state": engine_arbiter_state, + "engine_decode_runtime_state": engine_decode_runtime_state, + "engine_job_registry": engine_job_registry, + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": engine_prepare_state, + "engine_finalize_state": engine_finalize_state, + "engine_dispatcher_state": engine_dispatcher_state, + "engine_drained": bool(engine_drained), "request_registry": request_registry, "stage_summary": self._build_stage_summary(request_registry, scheduler_state), }
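
Implementation notes. The sketches below are illustrative only; any helper name not present in the patch (worker, _resolve, QueuedTask, MiniArbiter) is hypothetical.

1. Cross-thread completion. _notify_prepare_result, _notify_prepare_error, and _notify_dispatch_error hand results from the arbiter thread back to the caller's event loop. A minimal sketch of that pattern, assuming only stdlib asyncio/threading:

    import asyncio
    import threading
    import time

    def _resolve(future: asyncio.Future, payload: object) -> None:
        # Re-check done(): asyncio.wait_for may already have cancelled the future.
        if future.done():
            return
        future.set_result(payload)

    async def main() -> None:
        loop = asyncio.get_running_loop()
        done_future: asyncio.Future = loop.create_future()

        def worker() -> None:
            time.sleep(0.05)  # stands in for the GPU prepare stage
            try:
                # asyncio futures are not thread-safe; hop onto the owning loop first.
                loop.call_soon_threadsafe(_resolve, done_future, "prepared-state")
            except RuntimeError:
                pass  # loop already closed; nobody is waiting anymore

        threading.Thread(target=worker, daemon=True).start()
        print(await asyncio.wait_for(done_future, timeout=1.0))

    asyncio.run(main())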
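
2. Queue-wait accounting. EngineGpuPrepareTask, EngineDispatchTask, and SchedulerFinalizeTask each carry a perf_counter() enqueue stamp, and the stage runner derives the wait in milliseconds when it pops the task (feeding fields such as engine_dispatch_wait_ms and engine_gpu_prepare_queue_wait_ms). A self-contained sketch of the convention, with QueuedTask as a stand-in type:

    import time
    from collections import deque
    from dataclasses import dataclass, field

    @dataclass
    class QueuedTask:
        request_id: str
        enqueue_time: float = field(default_factory=time.perf_counter)

    queue = deque()
    queue.append(QueuedTask("req-1"))
    time.sleep(0.02)  # the arbiter is busy with another stage

    task = queue.popleft()
    # max(0.0, ...) guards against clock jitter, matching the patch.
    wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
    print(f"{task.request_id} waited {wait_ms:.1f} ms")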
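
3. Single-threaded stage arbiter. _run_engine_arbiter_loop multiplexes the prepare, finalize, decode_dispatch, and decode_runtime stages on one daemon thread and parks on an event instead of spinning when no stage is ready. A stripped-down sketch of the control flow, with a trivial stand-in for EnginePolicyArbiterController's stage selection:

    import threading
    from typing import Callable, Dict

    class MiniArbiter:
        def __init__(self, runners: Dict[str, Callable[[], bool]]) -> None:
            self.runners = runners  # stage name -> "did one unit of work run?"
            self._wakeup = threading.Event()

        def notify(self) -> None:
            # Producers call this after every enqueue (cf. _notify_engine_arbiter).
            self._wakeup.set()

        def run_forever(self) -> None:
            while True:
                executed = False
                # Stand-in policy: first ready stage wins; the real arbiter ranks
                # stages from queue depths, ages, and admission snapshots.
                for run_once in self.runners.values():
                    if run_once():
                        executed = True
                        break
                if not executed:
                    self._wakeup.wait()  # idle: sleep until a producer notifies
                    self._wakeup.clear()

Usage mirrors how engine_arbiter_thread is started in the constructor: threading.Thread(target=MiniArbiter({...}).run_forever, daemon=True).start().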
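
4. Latency decomposition. Both profile builders now subtract engine_policy_wait_ms before computing the residual bucket, so time a prepared state spends parked in the dispatch queue is no longer misattributed to request_other_ms. With illustrative numbers (not measurements from this patch):

    request_total_ms = 180.0
    prepare_wall_ms = 40.0
    engine_policy_wait_ms = 15.0
    api_after_prepare_ms = 2.0
    worker_total_ms = 110.0
    api_wait_result_ms = 8.0
    pack_ms = 3.0

    request_other_ms = max(
        0.0,
        request_total_ms
        - prepare_wall_ms
        - engine_policy_wait_ms
        - api_after_prepare_ms
        - worker_total_ms
        - api_wait_result_ms
        - pack_ms,
    )
    print(request_other_ms)  # 2.0 ms unattributed; clamped at zero if phases overlap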
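
5. Request lifecycle. Reading the _update_request_state calls in this patch, a request now moves roughly through NEW -> CPU_PREPARING -> GPU_PREPARING -> READY_FOR_PREFILL -> ACTIVE_DECODE -> READY_FOR_FINALIZE -> FINALIZING -> COMPLETED, with FAILED reachable from any stage via _fail_request_state. The QUEUED and STREAMING states appear only in the stage counters here, so their exact position in the sequence is inferred, not specified by this hunk.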