From b046a093d3f03b139d98e81dbb26a18e21cc5336 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 18:35:47 +0800 Subject: [PATCH] Add unified engine delegates and orchestration components for enhanced TTS processing Introduce new modules including EngineApiDelegates, EngineBridgeDelegates, EngineRegistryBridgeFacade, EngineRuntimeBridgeFacade, EngineStageBridgeFacade, and EngineStageOrchestrator. These additions provide a structured approach to managing TTS requests, engine states, and orchestration, significantly improving the architecture and maintainability of the TTS system. The new components support asynchronous operations and enhance overall performance through better request handling and processing capabilities. --- .../unified_engine_api_delegates.py | 165 +++++++ .../TTS_infer_pack/unified_engine_bridge.py | 313 +------------ .../unified_engine_bridge_delegates.py | 200 +++++++++ .../unified_engine_bridge_registry.py | 193 +++++++++ .../unified_engine_bridge_runtime.py | 33 ++ .../unified_engine_bridge_stage.py | 114 +++++ .../unified_engine_delegates.py | 410 +----------------- .../unified_engine_orchestration.py | 92 ++++ .../unified_engine_runtime_delegates.py | 46 ++ .../TTS_infer_pack/unified_engine_stage.py | 381 +++------------- .../unified_engine_stage_executor.py | 358 +++++++++++++++ 11 files changed, 1280 insertions(+), 1025 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py new file mode 100644 index 00000000..f42ec233 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, NormalizedEngineRequest + + +class EngineApiDelegates: + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.api_facade._collect_request_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.api_facade._has_active_request(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + return EngineApiFacade._build_request_meta(payload) + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + return EngineApiFacade._sum_profile_field(items, key) + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + + def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_direct_scheduler_profile(**kwargs) + + def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_legacy_direct_profile(**kwargs) + + def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_submit_profile(**kwargs) + + @staticmethod + def _format_ms_header(value: Any) -> str: + return EngineApiFacade._format_ms_header(value) + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + return self.api_facade._build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, + ) + + def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_debug_request_profile(**kwargs) + + @staticmethod + def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: + return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) + + def _normalize_lang(self, value: str | None) -> str | None: + return self.api_facade._normalize_lang(value) + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + return EngineApiFacade._aggregate_numeric_dicts(items) + + def _apply_default_reference(self, req: dict) -> dict: + return self.api_facade._apply_default_reference(req) + + def check_params(self, req: dict) -> Optional[str]: + return self.api_facade.check_params(req) + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return EngineApiFacade._base_request_defaults() + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + return self.api_facade._normalize_engine_request( + payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + return EngineApiFacade._normalize_streaming_mode(req) + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return self.api_facade._select_direct_backend(normalized) + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + return self.api_facade._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + return self.api_facade._should_use_scheduler_backend_for_direct(req) + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + return self.api_facade._segment_direct_text(normalized) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_scheduler(normalized) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return self.api_facade._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py index 536efbc5..d7740a52 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py @@ -1,310 +1,21 @@ from __future__ import annotations -import asyncio -import time -from typing import Any, Dict, List, Optional +from typing import Any -import numpy as np - -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_registry import EngineRegistryBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_runtime import EngineRuntimeBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_stage import EngineStageBridgeFacade class EngineBridgeFacade: def __init__(self, owner: Any) -> None: self.owner = owner + self.registry_bridge = EngineRegistryBridgeFacade(owner) + self.stage_bridge = EngineStageBridgeFacade(owner) + self.runtime_bridge = EngineRuntimeBridgeFacade(owner) - @property - def request_registry(self): - return self.owner.request_registry - - @property - def engine_prepare_queue_owner(self): - return self.owner.engine_prepare_queue_owner - - @property - def engine_finalize_queue_owner(self): - return self.owner.engine_finalize_queue_owner - - @property - def engine_dispatch_queue_owner(self): - return self.owner.engine_dispatch_queue_owner - - @property - def engine_decode_runtime_owner(self): - return self.owner.engine_decode_runtime_owner - - @property - def engine_job_registry(self): - return self.owner.engine_job_registry - - @property - def scheduler_worker(self): - return self.owner.scheduler_worker - - @property - def engine_stage_coordinator(self): - return self.owner.engine_stage_coordinator - - @property - def engine_policy_arbiter(self): - return self.owner.engine_policy_arbiter - - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.request_registry.register( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state( - self, - request_id: str, - status: str, - extra: Optional[Dict[str, Any]] = None, - ) -> None: - self.request_registry.update(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.merge_profile(request_id, extra) - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.complete(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.request_registry.fail(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.request_registry.snapshot() - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.engine_dispatch_queue_owner.snapshot( - max_request_ids=16, - extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, - ) - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.engine_job_registry.register(job, keep_job=True) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.get(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.pop(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.engine_job_registry.snapshot(max_request_ids=32) - - def _is_engine_drained(self) -> bool: - prepare_empty = self.engine_prepare_queue_owner.is_drained() - dispatch_empty = self.engine_dispatch_queue_owner.is_drained() - finalize_empty = self.engine_finalize_queue_owner.is_drained() - decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() - job_empty = self.engine_job_registry.is_empty() - worker_state = self.scheduler_worker.snapshot() - return bool( - prepare_empty - and dispatch_empty - and finalize_empty - and decode_pending_empty - and job_empty - and self.engine_decode_runtime_owner.get_active_batch() is None - and int(worker_state.get("prepare_inflight", 0)) <= 0 - and int(worker_state.get("finalize_inflight", 0)) <= 0 - and int(worker_state.get("finalize_pending", 0)) <= 0 - ) - - def _record_engine_job_done(self, request_id: str) -> None: - self.engine_job_registry.mark_finished_and_remove(request_id) - self.scheduler_worker.record_external_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - completion_bridge = self.scheduler_worker.completion_bridge - completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - completion_bridge.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), - on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), - ) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - if not request_ids: - return - completion_bridge = self.scheduler_worker.completion_bridge - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - completion_bridge.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), - ) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for job in jobs: - job.prefill_ms += delta_ms - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - activate_request_ids: List[str] = [] - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] - self._enqueue_worker_finished_for_finalize(tasks) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.engine_decode_runtime_owner.refresh_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - if self.scheduler_worker.is_engine_decode_control_enabled(): - return - self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - return self.engine_policy_arbiter.snapshot_state() - - def _notify_engine_arbiter(self) -> None: - self.engine_policy_arbiter.notify() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.engine_stage_coordinator.decode_runtime_owner.enqueue_pending_job(job) - self._notify_engine_arbiter() - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.engine_stage_coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) - - def _engine_has_pending_work(self) -> bool: - return self.engine_stage_coordinator.has_pending_work() - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=prepare_submit_at, - engine_request_id=engine_request_id, - ) - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=speed_factor, - sample_steps=sample_steps, - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id, - timeout_sec=timeout_sec, - ) - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() - self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) - return stage, reason, policy_snapshot, worker_state - - def _run_engine_prepare_once(self) -> bool: - return self.engine_stage_coordinator.run_engine_prepare_once() - - def _run_engine_finalize_once(self) -> bool: - return self.engine_stage_coordinator.run_engine_finalize_once() - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) - - def _run_engine_decode_runtime_once(self) -> bool: - return self.engine_stage_coordinator.run_engine_decode_runtime_once() - - def _run_engine_arbiter_loop(self) -> None: - self.engine_stage_coordinator.run_engine_arbiter_loop() + def __getattr__(self, name: str) -> Any: + for component in (self.registry_bridge, self.stage_bridge, self.runtime_bridge): + if hasattr(component, name): + return getattr(component, name) + raise AttributeError(name) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py new file mode 100644 index 00000000..92714750 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineRequestState, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineBridgeDelegates: + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.bridge_facade._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._merge_request_state_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_prepare_state() + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_finalize_state() + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_dispatch_state() + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._register_engine_job(job) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._get_engine_job(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._pop_engine_job(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_job_registry() + + def _is_engine_drained(self) -> bool: + return self.bridge_facade._is_engine_drained() + + def _record_engine_job_done(self, request_id: str) -> None: + self.bridge_facade._record_engine_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + self.bridge_facade._fail_engine_jobs(request_ids, error) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + self.bridge_facade._enqueue_engine_finished_items(items) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineBridgeFacade._summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.bridge_facade._refresh_engine_decode_runtime_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + self.bridge_facade._update_engine_decode_runtime_state(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_runtime_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_arbiter_state() + + def _notify_engine_arbiter(self) -> None: + self.bridge_facade._notify_engine_arbiter() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._enqueue_engine_decode_pending_job(job) + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.bridge_facade._peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.bridge_facade._engine_has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.bridge_facade._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.bridge_facade._take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.bridge_facade._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + return self.bridge_facade._select_engine_stage() + + def _run_engine_prepare_once(self) -> bool: + return self.bridge_facade._run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.bridge_facade._run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.bridge_facade._run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.bridge_facade._run_engine_arbiter_loop() + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.bridge_facade._fail_request_state(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_request_registry() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py new file mode 100644 index 00000000..88b8cc5d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineRegistryBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def request_registry(self): + return self.owner.request_registry + + @property + def engine_prepare_queue_owner(self): + return self.owner.engine_prepare_queue_owner + + @property + def engine_finalize_queue_owner(self): + return self.owner.engine_finalize_queue_owner + + @property + def engine_dispatch_queue_owner(self): + return self.owner.engine_dispatch_queue_owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_job_registry(self): + return self.owner.engine_job_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.request_registry.register( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.request_registry.update(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.merge_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.complete(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.request_registry.fail(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.request_registry.snapshot() + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py new file mode 100644 index 00000000..47be8b67 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner + + +class EngineRuntimeBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py new file mode 100644 index 00000000..29b5aaab --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineStageBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_stage_coordinator(self): + return self.owner.engine_stage_coordinator + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_decode_runtime_owner.enqueue_pending_job(job) + self.owner.engine_policy_arbiter.notify() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.engine_stage_coordinator.has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _run_engine_prepare_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + return self.engine_stage_coordinator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py index f68d3ede..d60a3bb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -1,401 +1,9 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple - -import numpy as np - -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState -from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade -from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution -from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade - - -class EngineBridgeDelegates: - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.bridge_facade._register_request_state( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._update_request_state(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._merge_request_state_profile(request_id, extra) - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_prepare_state() - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_finalize_state() - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_dispatch_state() - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.bridge_facade._register_engine_job(job) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.bridge_facade._get_engine_job(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.bridge_facade._pop_engine_job(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_job_registry() - - def _is_engine_drained(self) -> bool: - return self.bridge_facade._is_engine_drained() - - def _record_engine_job_done(self, request_id: str) -> None: - self.bridge_facade._record_engine_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - self.bridge_facade._fail_engine_jobs(request_ids, error) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - self.bridge_facade._enqueue_engine_finished_items(items) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_decode_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineBridgeFacade._summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.bridge_facade._refresh_engine_decode_runtime_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - self.bridge_facade._update_engine_decode_runtime_state(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_decode_runtime_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_arbiter_state() - - def _notify_engine_arbiter(self) -> None: - self.bridge_facade._notify_engine_arbiter() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.bridge_facade._enqueue_engine_decode_pending_job(job) - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - return self.bridge_facade._peek_queue_age_ms(queue_name) - - def _engine_has_pending_work(self) -> bool: - return self.bridge_facade._engine_has_pending_work() - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - return await self.bridge_facade._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=prepare_submit_at, - engine_request_id=engine_request_id, - ) - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - return self.bridge_facade._take_engine_finalize_batch_nonblocking() - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - return await self.bridge_facade._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=speed_factor, - sample_steps=sample_steps, - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id, - timeout_sec=timeout_sec, - ) - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - return self.bridge_facade._select_engine_stage() - - def _run_engine_prepare_once(self) -> bool: - return self.bridge_facade._run_engine_prepare_once() - - def _run_engine_finalize_once(self) -> bool: - return self.bridge_facade._run_engine_finalize_once() - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) - - def _run_engine_decode_runtime_once(self) -> bool: - return self.bridge_facade._run_engine_decode_runtime_once() - - def _run_engine_arbiter_loop(self) -> None: - self.bridge_facade._run_engine_arbiter_loop() - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._complete_request_state(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.bridge_facade._fail_request_state(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_request_registry() - - -class EngineApiDelegates: - def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - return self.api_facade._collect_request_summaries(request_ids) - - def _has_active_request(self, request_id: str) -> bool: - return self.api_facade._has_active_request(request_id) - - @staticmethod - def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: - return EngineApiFacade._build_request_meta(payload) - - @staticmethod - def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: - return EngineApiFacade._sum_profile_field(items, key) - - def _build_direct_segment_trace( - self, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) - - def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_direct_scheduler_profile(**kwargs) - - def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_legacy_direct_profile(**kwargs) - - def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_scheduler_submit_profile(**kwargs) - - @staticmethod - def _format_ms_header(value: Any) -> str: - return EngineApiFacade._format_ms_header(value) - - def _build_scheduler_submit_headers( - self, - *, - request_id: str, - media_type: str, - sample_rate: int, - profile: Dict[str, Any], - ) -> Dict[str, str]: - return self.api_facade._build_scheduler_submit_headers( - request_id=request_id, - media_type=media_type, - sample_rate=sample_rate, - profile=profile, - ) - - def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_scheduler_debug_request_profile(**kwargs) - - @staticmethod - def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: - return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) - - def _normalize_lang(self, value: str | None) -> str | None: - return self.api_facade._normalize_lang(value) - - @staticmethod - def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: - return EngineApiFacade._aggregate_numeric_dicts(items) - - def _apply_default_reference(self, req: dict) -> dict: - return self.api_facade._apply_default_reference(req) - - def check_params(self, req: dict) -> Optional[str]: - return self.api_facade.check_params(req) - - @staticmethod - def _base_request_defaults() -> Dict[str, Any]: - return EngineApiFacade._base_request_defaults() - - def _normalize_engine_request( - self, - payload: dict | NormalizedEngineRequest, - *, - request_id: str | None = None, - normalize_streaming: bool = False, - error_prefix: str = "request 参数非法: ", - ) -> NormalizedEngineRequest: - return self.api_facade._normalize_engine_request( - payload, - request_id=request_id, - normalize_streaming=normalize_streaming, - error_prefix=error_prefix, - ) - - @staticmethod - def _normalize_streaming_mode(req: dict) -> dict: - return EngineApiFacade._normalize_streaming_mode(req) - - @staticmethod - def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: - return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) - - def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - return self.api_facade._select_direct_backend(normalized) - - def _iter_legacy_direct_tts_bytes( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> Generator[bytes, None, None]: - return self.api_facade._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: - return self.api_facade._should_use_scheduler_backend_for_direct(req) - - def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: - return self.api_facade._segment_direct_text(normalized) - - def _build_segment_request( - self, - normalized: NormalizedEngineRequest, - *, - request_id: str, - text: str, - ) -> NormalizedEngineRequest: - return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) - - async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: - return await self.api_facade._run_direct_tts_via_scheduler(normalized) - - def _run_legacy_direct_tts_blocking( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - return self.api_facade._run_legacy_direct_tts_blocking( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - async def _run_direct_tts_via_legacy_backend( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - return await self.api_facade._run_direct_tts_via_legacy_backend( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - -class EngineRuntimeDelegates: - @staticmethod - def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: - return EngineRuntimeFacade._safe_component_snapshot(component) - - def _build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_stage_counters(request_registry, worker_state) - - def _build_engine_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) - - async def _wait_for_engine_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - return await self.engine_policy_arbiter.wait_for_policy_admission( - request_id=request_id, - timeout_sec=timeout_sec, - ) - - def _build_stage_summary( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_stage_summary(request_registry, worker_state) - - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: - self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates + +__all__ = [ + "EngineApiDelegates", + "EngineBridgeDelegates", + "EngineRuntimeDelegates", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py new file mode 100644 index 00000000..a71f7e4e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageOrchestrator: + def __init__( + self, + *, + executor: EngineStageExecutor, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.executor = executor + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None + self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None + self._wait_arbiter: Callable[[], None] | None = None + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self.executor.bind_notify_arbiter(notify_arbiter) + self._select_stage = select_stage + self._mark_arbiter_tick = mark_arbiter_tick + self._wait_arbiter = wait_arbiter + + def peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "finalize": + return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + if queue_name == "decode_runtime_pending": + return self.decode_runtime_owner.pending_age_ms() + return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( + "active_request_count", 0 + ) > 0: + return True + if self.prepare_queue_owner.has_items(): + return True + if self.finalize_queue_owner.has_items(): + return True + return self.dispatch_queue_owner.has_items() + + def run_engine_arbiter_loop(self) -> None: + if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: + raise RuntimeError("arbiter callbacks are not bound") + while True: + if not self.has_pending_work(): + self._mark_arbiter_tick("idle", "no_pending_work", True) + self._wait_arbiter() + continue + stage, reason, policy_snapshot, worker_state = self._select_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self.executor.run_engine_prepare_once() + elif stage == "finalize": + executed = self.executor.run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self.executor.run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) + self._wait_arbiter() + continue + self._mark_arbiter_tick(stage, reason, policy_allowed) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py new file mode 100644 index 00000000..96153196 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade + + +class EngineRuntimeDelegates: + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + return EngineRuntimeFacade._safe_component_snapshot(component) + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_summary(request_registry, worker_state) + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 65f0befe..9aad2fb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -1,20 +1,19 @@ from __future__ import annotations import asyncio -import time -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( EngineDecodeRuntimeOwner, EngineDispatchTask, - EngineGpuPrepareTask, - EngineStatus, EngineTaskQueueOwner, SchedulerFinalizeTask, SchedulerPendingJob, ) +from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker @@ -42,29 +41,36 @@ class EngineStageCoordinator: snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], ) -> None: - self.tts = tts - self.scheduler_worker = scheduler_worker - self.prepare_queue_owner = prepare_queue_owner - self.finalize_queue_owner = finalize_queue_owner - self.dispatch_queue_owner = dispatch_queue_owner - self.decode_runtime_owner = decode_runtime_owner - self.update_request_state = update_request_state - self.merge_request_state_profile = merge_request_state_profile - self.fail_request_state = fail_request_state - self.get_engine_job = get_engine_job - self.register_engine_job = register_engine_job - self.fail_engine_jobs = fail_engine_jobs - self.complete_engine_job = complete_engine_job - self.add_engine_prefill_time = add_engine_prefill_time - self.add_engine_merge_time = add_engine_merge_time - self.add_engine_decode_time = add_engine_decode_time - self.enqueue_engine_finished_items = enqueue_engine_finished_items - self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state - self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state - self._notify_arbiter: Callable[[], None] | None = None - self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None - self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None - self._wait_arbiter: Callable[[], None] | None = None + self.executor = EngineStageExecutor( + tts=tts, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + update_request_state=update_request_state, + merge_request_state_profile=merge_request_state_profile, + fail_request_state=fail_request_state, + get_engine_job=get_engine_job, + register_engine_job=register_engine_job, + fail_engine_jobs=fail_engine_jobs, + complete_engine_job=complete_engine_job, + add_engine_prefill_time=add_engine_prefill_time, + add_engine_merge_time=add_engine_merge_time, + add_engine_decode_time=add_engine_decode_time, + enqueue_engine_finished_items=enqueue_engine_finished_items, + snapshot_engine_dispatch_state=snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) + self.orchestrator = EngineStageOrchestrator( + executor=self.executor, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) def bind_arbiter( self, @@ -74,57 +80,12 @@ class EngineStageCoordinator: mark_arbiter_tick: Callable[[str, str, bool], None], wait_arbiter: Callable[[], None], ) -> None: - self._notify_arbiter = notify_arbiter - self._select_stage = select_stage - self._mark_arbiter_tick = mark_arbiter_tick - self._wait_arbiter = wait_arbiter - - def notify_arbiter(self) -> None: - if self._notify_arbiter is not None: - self._notify_arbiter() - - @staticmethod - def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: - if future.done(): - return - future.set_exception(error) - - def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - @staticmethod - def _resolve_prepare_future( - future: asyncio.Future, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if future.done(): - return - future.set_result(payload) - - def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - def _notify_prepare_result( - self, - task: EngineGpuPrepareTask, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) - except RuntimeError: - pass + self.orchestrator.bind_arbiter( + notify_arbiter=notify_arbiter, + select_stage=select_stage, + mark_arbiter_tick=mark_arbiter_tick, + wait_arbiter=wait_arbiter, + ) async def prepare_state_via_engine_gpu_queue( self, @@ -133,59 +94,17 @@ class EngineStageCoordinator: prepare_submit_at: float, engine_request_id: str | None, ) -> tuple[T2SRequestState, float, float]: - cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - if engine_request_id not in [None, ""]: - self.update_request_state( - str(engine_request_id), - EngineStatus.GPU_PREPARING, - { - "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), - "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), - "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), - "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), - }, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - task = EngineGpuPrepareTask( - request_id=spec.request_id, - cpu_stage=cpu_stage, - done_loop=loop, - done_future=done_future, - engine_request_id=engine_request_id or spec.request_id, - enqueue_time=time.perf_counter(), + return await self.executor.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, ) - self.prepare_queue_owner.enqueue(task) - self.notify_arbiter() - state, prepare_exec_started_at, prepare_exec_finished_at = await done_future - return state, prepare_exec_started_at, prepare_exec_finished_at def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - self.update_request_state( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": task.item.finish_reason, - "semantic_len": int(task.item.semantic_tokens.shape[0]), - "finish_idx": int(task.item.finish_idx), - }, - ) - self.finalize_queue_owner.enqueue_many(tasks) - self.notify_arbiter() + self.executor.enqueue_worker_finished_for_finalize(tasks) def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - finalize_policy = self.scheduler_worker.get_finalize_batch_policy() - return self.finalize_queue_owner.take_finalize_batch( - finalize_mode=str(finalize_policy.get("finalize_mode", "async")), - batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), - batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), - use_vocoder=bool(self.tts.configs.use_vocoder), - ) + return self.executor.take_engine_finalize_batch_nonblocking() async def enqueue_prepared_state_for_dispatch( self, @@ -201,220 +120,36 @@ class EngineStageCoordinator: engine_request_id: str | None, timeout_sec: float | None, ) -> EngineDispatchTask: - task = EngineDispatchTask( - request_id=state.request_id, + return await self.executor.enqueue_prepared_state_for_dispatch( state=state, - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), + speed_factor=speed_factor, + sample_steps=sample_steps, media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, done_future=done_future, - engine_request_id=engine_request_id or state.request_id, + engine_request_id=engine_request_id, timeout_sec=timeout_sec, - enqueue_time=time.perf_counter(), ) - self.dispatch_queue_owner.enqueue(task) - self.notify_arbiter() - self.merge_request_state_profile( - task.engine_request_id or task.request_id, - { - "engine_dispatch_queue_depth_on_enqueue": int( - self.snapshot_engine_dispatch_state()["waiting_count"] - ), - }, - ) - return task def peek_queue_age_ms(self, queue_name: str) -> float: - if queue_name == "prepare": - return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") - if queue_name == "finalize": - return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") - if queue_name == "decode_runtime_pending": - return self.decode_runtime_owner.pending_age_ms() - return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + return self.orchestrator.peek_queue_age_ms(queue_name) def has_pending_work(self) -> bool: - if self.scheduler_worker.is_engine_decode_control_enabled(): - if self.decode_runtime_owner.has_pending_jobs(): - return True - if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( - "active_request_count", 0 - ) > 0: - return True - if self.prepare_queue_owner.has_items(): - return True - if self.finalize_queue_owner.has_items(): - return True - return self.dispatch_queue_owner.has_items() + return self.orchestrator.has_pending_work() def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: - return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) - if task.engine_request_id not in [None, ""]: - self.merge_request_state_profile( - str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, - ) - self.prepare_queue_owner.mark_completed(1) - self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True + return self.executor.run_engine_prepare_once() def run_engine_finalize_once(self) -> bool: - tasks = self.take_engine_finalize_batch_nonblocking() - if not tasks: - return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return False - now = time.perf_counter() - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) - for job, item in jobs_and_items: - self.update_request_state( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) - for job, _ in jobs_and_items: - job.synth_ms += float(synth_ms) - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) - finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) - return True + return self.executor.run_engine_finalize_once() def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - if not bool(policy_snapshot.get("allowed", True)): - return False - dispatch_task = self.dispatch_queue_owner.pop_left() - if dispatch_task is None: - return False - dispatched_at = time.perf_counter() - dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) - dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_policy_snapshot = dict(policy_snapshot) - try: - worker_job = self.scheduler_worker.submit( - state=dispatch_task.state, - speed_factor=dispatch_task.speed_factor, - sample_steps=dispatch_task.sample_steps, - media_type=dispatch_task.media_type, - prepare_wall_ms=dispatch_task.prepare_wall_ms, - prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, - done_loop=dispatch_task.done_loop, - done_future=dispatch_task.done_future, - engine_request_id=dispatch_task.engine_request_id, - timeout_sec=dispatch_task.timeout_sec, - skip_capacity_wait=True, - admission_wait_ms_override=0.0, - admission_snapshot_override=dict(worker_state), - engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, - engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, - enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), - ) - dispatch_task.worker_job = worker_job - self.register_engine_job(worker_job) - if self.scheduler_worker.is_engine_decode_control_enabled(): - self.decode_runtime_owner.enqueue_pending_job(worker_job) - self.notify_arbiter() - self.dispatch_queue_owner.mark_completed(1) - return True - except Exception as exc: - dispatch_task.error = str(exc) - self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) - self._notify_dispatch_error(dispatch_task, exc) - return True + return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) def run_engine_decode_runtime_once(self) -> bool: - if not self.scheduler_worker.is_engine_decode_control_enabled(): - return False - runtime_state = self.snapshot_engine_decode_runtime_state() - pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( - wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 - ) - result = self.scheduler_worker.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=self.decode_runtime_owner.get_active_batch(), - external_bookkeeping=True, - ) - prefill_phase = dict(result.get("prefill_phase") or {}) - if prefill_phase.get("error"): - self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) - else: - prefill_jobs = list(prefill_phase.get("pending_jobs") or []) - self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) - self.add_engine_merge_time( - [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), - float(prefill_phase.get("merge_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) - decode_phase = dict(result.get("decode_phase") or {}) - if decode_phase.get("error"): - self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) - else: - self.add_engine_decode_time( - list(decode_phase.get("request_ids") or []), - float(decode_phase.get("decode_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) - self.decode_runtime_owner.set_active_batch(result.get("active_batch")) - if result.get("executed", False): - self.decode_runtime_owner.refresh_state("engine_decode_cycle") - return bool(result.get("executed", False)) + return self.executor.run_engine_decode_runtime_once() def run_engine_arbiter_loop(self) -> None: - if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: - raise RuntimeError("arbiter callbacks are not bound") - while True: - if not self.has_pending_work(): - self._mark_arbiter_tick("idle", "no_pending_work", True) - self._wait_arbiter() - continue - stage, reason, policy_snapshot, worker_state = self._select_stage() - policy_allowed = bool(policy_snapshot.get("allowed", True)) - executed = False - if stage == "prepare": - executed = self.run_engine_prepare_once() - elif stage == "finalize": - executed = self.run_engine_finalize_once() - elif stage == "decode_dispatch": - executed = self.run_engine_dispatch_once(policy_snapshot, worker_state) - elif stage == "decode_runtime": - executed = self.run_engine_decode_runtime_once() - if not executed: - self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) - self._wait_arbiter() - continue - self._mark_arbiter_tick(stage, reason, policy_allowed) + self.orchestrator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py new file mode 100644 index 00000000..77274056 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineDispatchTask, + EngineGpuPrepareTask, + EngineStatus, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageExecutor: + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.tts = tts + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.update_request_state = update_request_state + self.merge_request_state_profile = merge_request_state_profile + self.fail_request_state = fail_request_state + self.get_engine_job = get_engine_job + self.register_engine_job = register_engine_job + self.fail_engine_jobs = fail_engine_jobs + self.complete_engine_job = complete_engine_job + self.add_engine_prefill_time = add_engine_prefill_time + self.add_engine_merge_time = add_engine_merge_time + self.add_engine_decode_time = add_engine_decode_time + self.enqueue_engine_finished_items = enqueue_engine_finished_items + self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._notify_arbiter: Callable[[], None] | None = None + + def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None: + self._notify_arbiter = notify_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec: Any, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + return await done_future + + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def run_engine_prepare_once(self) -> bool: + task = self.prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False))