# Module: GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py
class EngineApiDelegates:
    """Pass-through layer in front of an ``EngineApiFacade``.

    Instance methods forward to the same-named method on ``self.api_facade``.
    This class reads that attribute but never assigns it, so the host object
    must provide it (presumably this class is used as a mixin -- confirm
    against the class that inherits from it). Static methods forward to the
    same-named attribute looked up on the ``EngineApiFacade`` class itself.
    """

    def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
        """Forward ``request_ids`` to ``api_facade._collect_request_summaries``."""
        facade = self.api_facade
        return facade._collect_request_summaries(request_ids)

    def _has_active_request(self, request_id: str) -> bool:
        """Forward ``request_id`` to ``api_facade._has_active_request``."""
        facade = self.api_facade
        return facade._has_active_request(request_id)

    @staticmethod
    def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Forward ``payload`` to ``EngineApiFacade._build_request_meta``."""
        target = EngineApiFacade._build_request_meta
        return target(payload)

    @staticmethod
    def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
        """Forward to ``EngineApiFacade._sum_profile_field``."""
        target = EngineApiFacade._sum_profile_field
        return target(items, key)

    def _build_direct_segment_trace(
        self,
        segment_texts: Sequence[str],
        prepare_profiles: Sequence[Dict[str, Any]],
        worker_profiles: Sequence[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Forward the three sequences, positionally, to the facade."""
        facade = self.api_facade
        return facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)

    def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
        """Forward all keyword arguments unchanged to the facade."""
        facade = self.api_facade
        return facade._build_direct_scheduler_profile(**kwargs)

    def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
        """Forward all keyword arguments unchanged to the facade."""
        facade = self.api_facade
        return facade._build_legacy_direct_profile(**kwargs)

    def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
        """Forward all keyword arguments unchanged to the facade."""
        facade = self.api_facade
        return facade._build_scheduler_submit_profile(**kwargs)

    @staticmethod
    def _format_ms_header(value: Any) -> str:
        """Forward ``value`` to ``EngineApiFacade._format_ms_header``."""
        target = EngineApiFacade._format_ms_header
        return target(value)

    def _build_scheduler_submit_headers(
        self,
        *,
        request_id: str,
        media_type: str,
        sample_rate: int,
        profile: Dict[str, Any],
    ) -> Dict[str, str]:
        """Forward the four keyword-only arguments, by keyword, to the facade."""
        facade = self.api_facade
        return facade._build_scheduler_submit_headers(
            request_id=request_id,
            media_type=media_type,
            sample_rate=sample_rate,
            profile=profile,
        )

    def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
        """Forward all keyword arguments unchanged to the facade."""
        facade = self.api_facade
        return facade._build_scheduler_debug_request_profile(**kwargs)

    @staticmethod
    def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
        """Forward all keyword arguments to the ``EngineApiFacade`` static helper."""
        target = EngineApiFacade._build_scheduler_debug_batch_profile
        return target(**kwargs)

    def _normalize_lang(self, value: str | None) -> str | None:
        """Forward ``value`` to ``api_facade._normalize_lang``."""
        facade = self.api_facade
        return facade._normalize_lang(value)

    @staticmethod
    def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
        """Forward ``items`` to ``EngineApiFacade._aggregate_numeric_dicts``."""
        target = EngineApiFacade._aggregate_numeric_dicts
        return target(items)

    def _apply_default_reference(self, req: dict) -> dict:
        """Forward ``req`` to ``api_facade._apply_default_reference``."""
        facade = self.api_facade
        return facade._apply_default_reference(req)

    def check_params(self, req: dict) -> Optional[str]:
        """Forward ``req`` to ``api_facade.check_params`` and return its result."""
        facade = self.api_facade
        return facade.check_params(req)

    @staticmethod
    def _base_request_defaults() -> Dict[str, Any]:
        """Return whatever ``EngineApiFacade._base_request_defaults()`` returns."""
        target = EngineApiFacade._base_request_defaults
        return target()

    def _normalize_engine_request(
        self,
        payload: dict | NormalizedEngineRequest,
        *,
        request_id: str | None = None,
        normalize_streaming: bool = False,
        error_prefix: str = "request 参数非法: ",
    ) -> NormalizedEngineRequest:
        """Forward ``payload`` positionally and the options by keyword.

        The default ``error_prefix`` is a user-facing Chinese string
        (roughly "invalid request parameters: ") and is passed through as is.
        """
        facade = self.api_facade
        return facade._normalize_engine_request(
            payload,
            request_id=request_id,
            normalize_streaming=normalize_streaming,
            error_prefix=error_prefix,
        )

    @staticmethod
    def _normalize_streaming_mode(req: dict) -> dict:
        """Forward ``req`` to ``EngineApiFacade._normalize_streaming_mode``."""
        target = EngineApiFacade._normalize_streaming_mode
        return target(req)

    @staticmethod
    def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
        """Forward to ``EngineApiFacade._is_aux_ref_enabled``."""
        target = EngineApiFacade._is_aux_ref_enabled
        return target(aux_ref_audio_paths)

    def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
        """Forward ``normalized`` to ``api_facade._select_direct_backend``."""
        facade = self.api_facade
        return facade._select_direct_backend(normalized)

    def _iter_legacy_direct_tts_bytes(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> Generator[bytes, None, None]:
        """Return the facade's iterator unchanged (this wrapper does not iterate it)."""
        facade = self.api_facade
        return facade._iter_legacy_direct_tts_bytes(
            normalized,
            backend=backend,
            fallback_reason=fallback_reason,
        )

    def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
        """Forward ``req`` to the facade's same-named predicate."""
        facade = self.api_facade
        return facade._should_use_scheduler_backend_for_direct(req)

    def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
        """Forward ``normalized`` to ``api_facade._segment_direct_text``."""
        facade = self.api_facade
        return facade._segment_direct_text(normalized)

    def _build_segment_request(
        self,
        normalized: NormalizedEngineRequest,
        *,
        request_id: str,
        text: str,
    ) -> NormalizedEngineRequest:
        """Forward ``normalized`` positionally, ``request_id``/``text`` by keyword."""
        facade = self.api_facade
        return facade._build_segment_request(normalized, request_id=request_id, text=text)

    async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
        """Await the facade's same-named coroutine and return its result."""
        pending = self.api_facade._run_direct_tts_via_scheduler(normalized)
        return await pending

    def _run_legacy_direct_tts_blocking(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> DirectTTSExecution:
        """Forward to ``api_facade._run_legacy_direct_tts_blocking``."""
        facade = self.api_facade
        return facade._run_legacy_direct_tts_blocking(
            normalized,
            backend=backend,
            fallback_reason=fallback_reason,
        )

    async def _run_direct_tts_via_legacy_backend(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> DirectTTSExecution:
        """Await the facade's same-named coroutine and return its result."""
        pending = self.api_facade._run_direct_tts_via_legacy_backend(
            normalized,
            backend=backend,
            fallback_reason=fallback_reason,
        )
        return await pending
# Module: GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py
class EngineBridgeFacade:
    """Single entry point that fans attribute lookups out to three sub-facades.

    ``__init__`` builds one registry, one stage and one runtime bridge around
    the same ``owner``. Any attribute not found on this object by normal
    lookup is resolved on those bridges, in that order, so the first bridge
    that has the name wins.
    """

    def __init__(self, owner: Any) -> None:
        """Store ``owner`` and create the three sub-facades that wrap it."""
        self.owner = owner
        self.registry_bridge = EngineRegistryBridgeFacade(owner)
        self.stage_bridge = EngineStageBridgeFacade(owner)
        self.runtime_bridge = EngineRuntimeBridgeFacade(owner)

    @staticmethod
    def _summarize_active_batch(active_batch: Any) -> Any:
        """Class-level entry point for the runtime bridge's static helper.

        ``__getattr__`` only serves lookups on *instances*. Callers that use
        ``EngineBridgeFacade._summarize_active_batch(...)`` on the class (as
        ``EngineBridgeDelegates`` does) need a real class attribute, so the
        static method is kept here and forwards to the runtime bridge.
        """
        return EngineRuntimeBridgeFacade._summarize_active_batch(active_batch)

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` on the registry, stage, then runtime bridge.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: if no available sub-facade exposes ``name``.
        """
        # Read the sub-facades straight from the instance dict. Going through
        # ``self.registry_bridge`` would re-enter __getattr__ whenever that
        # attribute is not set yet (copy/pickle reconstruction, __new__, or an
        # __init__ that failed early) and end in RecursionError.
        state = self.__dict__
        for slot in ("registry_bridge", "stage_bridge", "runtime_bridge"):
            component = state.get(slot)
            if component is None:
                continue
            try:
                return getattr(component, name)
            except AttributeError:
                # Not provided by this bridge; try the next one.
                continue
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
# Module: GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py
class EngineBridgeDelegates:
    """Pass-through layer in front of an engine bridge facade.

    Every instance method forwards to the same-named method on
    ``self.bridge_facade``. This class reads that attribute but never assigns
    it, so the host object must provide it (presumably this class is used as a
    mixin -- confirm against the class that inherits from it).
    """

    def _register_request_state(
        self,
        request_id: str,
        api_mode: str,
        backend: str,
        media_type: str,
        response_streaming: bool,
        deadline_ts: float | None = None,
        meta: Optional[Dict[str, Any]] = None,
    ) -> EngineRequestState:
        """Forward all arguments, by keyword, to the facade."""
        return self.bridge_facade._register_request_state(
            request_id=request_id,
            api_mode=api_mode,
            backend=backend,
            media_type=media_type,
            response_streaming=response_streaming,
            deadline_ts=deadline_ts,
            meta=meta,
        )

    def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward to ``bridge_facade._update_request_state``."""
        self.bridge_facade._update_request_state(request_id, status, extra)

    def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward to ``bridge_facade._merge_request_state_profile``."""
        self.bridge_facade._merge_request_state_profile(request_id, extra)

    def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
        """Return the facade's prepare-state snapshot."""
        return self.bridge_facade._snapshot_engine_prepare_state()

    def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
        """Return the facade's finalize-state snapshot."""
        return self.bridge_facade._snapshot_engine_finalize_state()

    def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
        """Return the facade's dispatch-state snapshot."""
        return self.bridge_facade._snapshot_engine_dispatch_state()

    def _register_engine_job(self, job: SchedulerPendingJob) -> None:
        """Forward ``job`` to ``bridge_facade._register_engine_job``."""
        self.bridge_facade._register_engine_job(job)

    def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Forward ``request_id`` to ``bridge_facade._get_engine_job``."""
        return self.bridge_facade._get_engine_job(request_id)

    def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Forward ``request_id`` to ``bridge_facade._pop_engine_job``."""
        return self.bridge_facade._pop_engine_job(request_id)

    def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
        """Return the facade's job-registry snapshot."""
        return self.bridge_facade._snapshot_engine_job_registry()

    def _is_engine_drained(self) -> bool:
        """Return the facade's ``_is_engine_drained()`` result."""
        return self.bridge_facade._is_engine_drained()

    def _record_engine_job_done(self, request_id: str) -> None:
        """Forward ``request_id`` to ``bridge_facade._record_engine_job_done``."""
        self.bridge_facade._record_engine_job_done(request_id)

    def _complete_engine_job(
        self,
        job: SchedulerPendingJob,
        item: T2SFinishedItem,
        *,
        sample_rate: int,
        audio_data: np.ndarray,
    ) -> None:
        """Forward ``job``/``item`` positionally and the audio fields by keyword."""
        self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)

    def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
        """Forward to ``bridge_facade._fail_engine_jobs``."""
        self.bridge_facade._fail_engine_jobs(request_ids, error)

    def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
        """Forward to ``bridge_facade._add_engine_prefill_time``."""
        self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s)

    def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Forward to ``bridge_facade._add_engine_merge_time``."""
        self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s)

    def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Forward to ``bridge_facade._add_engine_decode_time``."""
        self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s)

    def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
        """Forward ``items`` to ``bridge_facade._enqueue_engine_finished_items``."""
        self.bridge_facade._enqueue_engine_finished_items(items)

    def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
        """Return the facade's decode pending-queue snapshot."""
        return self.bridge_facade._snapshot_engine_decode_pending_queue_state()

    @staticmethod
    def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
        """Summarize ``active_batch`` via ``EngineDecodeRuntimeOwner``.

        The refactored ``EngineBridgeFacade`` only forwards attribute lookups
        on instances, so the previous class-level call
        ``EngineBridgeFacade._summarize_active_batch(...)`` raises
        ``AttributeError``. This static method has no instance to go through,
        so it calls the underlying static helper directly -- the same target
        the runtime bridge forwards to.
        """
        # Function-scope import: this module's top-level import list does not
        # bring EngineDecodeRuntimeOwner into scope.
        from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner

        return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)

    def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
        """Forward ``last_event`` to the facade."""
        self.bridge_facade._refresh_engine_decode_runtime_state(last_event)

    def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
        """Forward ``snapshot`` to the facade."""
        self.bridge_facade._update_engine_decode_runtime_state(snapshot)

    def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
        """Return the facade's decode-runtime snapshot."""
        return self.bridge_facade._snapshot_engine_decode_runtime_state()

    def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
        """Return the facade's arbiter snapshot."""
        return self.bridge_facade._snapshot_engine_arbiter_state()

    def _notify_engine_arbiter(self) -> None:
        """Call ``bridge_facade._notify_engine_arbiter()``."""
        self.bridge_facade._notify_engine_arbiter()

    def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
        """Forward ``job`` to the facade."""
        self.bridge_facade._enqueue_engine_decode_pending_job(job)

    def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
        """Forward ``wait_for_batch`` to the facade and return its result."""
        return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)

    def _peek_queue_age_ms(self, queue_name: str) -> float:
        """Forward ``queue_name`` to ``bridge_facade._peek_queue_age_ms``."""
        return self.bridge_facade._peek_queue_age_ms(queue_name)

    def _engine_has_pending_work(self) -> bool:
        """Return the facade's ``_engine_has_pending_work()`` result."""
        return self.bridge_facade._engine_has_pending_work()

    async def _prepare_state_via_engine_gpu_queue(
        self,
        *,
        spec: SchedulerRequestSpec,
        prepare_submit_at: float,
        engine_request_id: str | None,
    ) -> tuple[T2SRequestState, float, float]:
        """Await the facade's same-named coroutine and return its result."""
        return await self.bridge_facade._prepare_state_via_engine_gpu_queue(
            spec=spec,
            prepare_submit_at=prepare_submit_at,
            engine_request_id=engine_request_id,
        )

    def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
        """Forward ``tasks`` to the facade."""
        self.bridge_facade._enqueue_worker_finished_for_finalize(tasks)

    def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
        """Return the facade's next finalize batch."""
        return self.bridge_facade._take_engine_finalize_batch_nonblocking()

    async def _enqueue_prepared_state_for_dispatch(
        self,
        *,
        state: T2SRequestState,
        speed_factor: float,
        sample_steps: int,
        media_type: str,
        prepare_wall_ms: float,
        prepare_profile_total_ms: float,
        done_loop: asyncio.AbstractEventLoop | None,
        done_future: asyncio.Future | None,
        engine_request_id: str | None,
        timeout_sec: float | None,
    ) -> EngineDispatchTask:
        """Await the facade's same-named coroutine, passing every field by keyword."""
        return await self.bridge_facade._enqueue_prepared_state_for_dispatch(
            state=state,
            speed_factor=speed_factor,
            sample_steps=sample_steps,
            media_type=media_type,
            prepare_wall_ms=prepare_wall_ms,
            prepare_profile_total_ms=prepare_profile_total_ms,
            done_loop=done_loop,
            done_future=done_future,
            engine_request_id=engine_request_id,
            timeout_sec=timeout_sec,
        )

    def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
        """Forward the three keyword-only arguments to the facade."""
        self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)

    def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
        """Return the facade's ``_select_engine_stage()`` result."""
        return self.bridge_facade._select_engine_stage()

    def _run_engine_prepare_once(self) -> bool:
        """Return the facade's ``_run_engine_prepare_once()`` result."""
        return self.bridge_facade._run_engine_prepare_once()

    def _run_engine_finalize_once(self) -> bool:
        """Return the facade's ``_run_engine_finalize_once()`` result."""
        return self.bridge_facade._run_engine_finalize_once()

    def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
        """Forward both snapshots positionally to the facade."""
        return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state)

    def _run_engine_decode_runtime_once(self) -> bool:
        """Return the facade's ``_run_engine_decode_runtime_once()`` result."""
        return self.bridge_facade._run_engine_decode_runtime_once()

    def _run_engine_arbiter_loop(self) -> None:
        """Call ``bridge_facade._run_engine_arbiter_loop()``."""
        self.bridge_facade._run_engine_arbiter_loop()

    def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward to ``bridge_facade._complete_request_state``."""
        self.bridge_facade._complete_request_state(request_id, extra)

    def _fail_request_state(self, request_id: str, error: str) -> None:
        """Forward to ``bridge_facade._fail_request_state``."""
        self.bridge_facade._fail_request_state(request_id, error)

    def _snapshot_request_registry(self) -> Dict[str, Any]:
        """Return the facade's request-registry snapshot."""
        return self.bridge_facade._snapshot_request_registry()
# Module: GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py
class EngineRegistryBridgeFacade:
    """Bridge over the owner's request registry, job registry and queue owners.

    Every collaborator is read from ``self.owner`` on each access through the
    properties below; nothing is cached on this object.
    """

    def __init__(self, owner: Any) -> None:
        """Keep a reference to ``owner``, the object that holds the collaborators."""
        self.owner = owner

    @property
    def request_registry(self):
        """``owner.request_registry``."""
        owner = self.owner
        return owner.request_registry

    @property
    def engine_prepare_queue_owner(self):
        """``owner.engine_prepare_queue_owner``."""
        owner = self.owner
        return owner.engine_prepare_queue_owner

    @property
    def engine_finalize_queue_owner(self):
        """``owner.engine_finalize_queue_owner``."""
        owner = self.owner
        return owner.engine_finalize_queue_owner

    @property
    def engine_dispatch_queue_owner(self):
        """``owner.engine_dispatch_queue_owner``."""
        owner = self.owner
        return owner.engine_dispatch_queue_owner

    @property
    def engine_decode_runtime_owner(self):
        """``owner.engine_decode_runtime_owner``."""
        owner = self.owner
        return owner.engine_decode_runtime_owner

    @property
    def engine_job_registry(self):
        """``owner.engine_job_registry``."""
        owner = self.owner
        return owner.engine_job_registry

    @property
    def scheduler_worker(self):
        """``owner.scheduler_worker``."""
        owner = self.owner
        return owner.scheduler_worker

    def _register_request_state(
        self,
        request_id: str,
        api_mode: str,
        backend: str,
        media_type: str,
        response_streaming: bool,
        deadline_ts: float | None = None,
        meta: Optional[Dict[str, Any]] = None,
    ) -> EngineRequestState:
        """Register a request with the request registry, passing every field by keyword."""
        fields = {
            "request_id": request_id,
            "api_mode": api_mode,
            "backend": backend,
            "media_type": media_type,
            "response_streaming": response_streaming,
            "deadline_ts": deadline_ts,
            "meta": meta,
        }
        return self.request_registry.register(**fields)

    def _update_request_state(
        self,
        request_id: str,
        status: str,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Call ``request_registry.update(request_id, status, extra)``."""
        registry = self.request_registry
        registry.update(request_id, status, extra)

    def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Call ``request_registry.merge_profile(request_id, extra)``."""
        registry = self.request_registry
        registry.merge_profile(request_id, extra)

    def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Call ``request_registry.complete(request_id, extra)``."""
        registry = self.request_registry
        registry.complete(request_id, extra)

    def _fail_request_state(self, request_id: str, error: str) -> None:
        """Call ``request_registry.fail(request_id, error)``."""
        registry = self.request_registry
        registry.fail(request_id, error)

    def _snapshot_request_registry(self) -> Dict[str, Any]:
        """Return ``request_registry.snapshot()``."""
        registry = self.request_registry
        return registry.snapshot()

    def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
        """Snapshot the prepare queue owner, asking for at most 16 request ids."""
        queue_owner = self.engine_prepare_queue_owner
        return queue_owner.snapshot(max_request_ids=16)

    def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
        """Snapshot the finalize queue owner, asking for at most 16 request ids."""
        queue_owner = self.engine_finalize_queue_owner
        return queue_owner.snapshot(max_request_ids=16)

    def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
        """Snapshot the dispatch queue owner plus a copy of the last policy snapshot.

        ``owner.engine_dispatch_last_snapshot`` may be falsy (e.g. ``None``),
        in which case an empty dict is attached instead.
        """
        queue_owner = self.engine_dispatch_queue_owner
        last_policy = dict(self.owner.engine_dispatch_last_snapshot or {})
        return queue_owner.snapshot(
            max_request_ids=16,
            extra={"last_policy_snapshot": last_policy},
        )

    def _register_engine_job(self, job: SchedulerPendingJob) -> None:
        """Register ``job`` in the job registry with ``keep_job=True``."""
        registry = self.engine_job_registry
        registry.register(job, keep_job=True)

    def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Return ``engine_job_registry.get(request_id)``."""
        registry = self.engine_job_registry
        return registry.get(request_id)

    def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Return ``engine_job_registry.pop(request_id)``."""
        registry = self.engine_job_registry
        return registry.pop(request_id)

    def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
        """Snapshot the job registry, asking for at most 32 request ids."""
        registry = self.engine_job_registry
        return registry.snapshot(max_request_ids=32)

    def _is_engine_drained(self) -> bool:
        """Return True only when every queue, registry and worker counter is idle.

        The three queue owners, the decode pending queue, the job registry and
        the worker snapshot are all probed first, in that order. The active
        batch and the worker counters are inspected only if those are idle.
        """
        prepare_idle = self.engine_prepare_queue_owner.is_drained()
        dispatch_idle = self.engine_dispatch_queue_owner.is_drained()
        finalize_idle = self.engine_finalize_queue_owner.is_drained()
        no_decode_backlog = not self.engine_decode_runtime_owner.has_pending_jobs()
        registry_empty = self.engine_job_registry.is_empty()
        worker = self.scheduler_worker.snapshot()

        if not (prepare_idle and dispatch_idle and finalize_idle and no_decode_backlog and registry_empty):
            return False
        if self.engine_decode_runtime_owner.get_active_batch() is not None:
            return False
        # Missing counters default to 0; any positive value means work remains.
        for counter in ("prepare_inflight", "finalize_inflight", "finalize_pending"):
            if int(worker.get(counter, 0)) > 0:
                return False
        return True

    def _record_engine_job_done(self, request_id: str) -> None:
        """Drop the job from the registry, then tell the scheduler worker it is done."""
        registry = self.engine_job_registry
        registry.mark_finished_and_remove(request_id)
        worker = self.scheduler_worker
        worker.record_external_job_done(request_id)

    def _complete_engine_job(
        self,
        job: SchedulerPendingJob,
        item: T2SFinishedItem,
        *,
        sample_rate: int,
        audio_data: np.ndarray,
    ) -> None:
        """Build the job result, then complete the job through the completion bridge.

        The completion bridge receives a callback that records the job as done
        using ``item.request_id`` captured at call time.
        """
        bridge = self.scheduler_worker.completion_bridge
        bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)

        runtime_request_id = job.engine_request_id
        runtime_extra = bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate)
        finished_id = item.request_id

        def _mark_done(rid=finished_id):
            return self._record_engine_job_done(rid)

        bridge.complete_job(
            job,
            runtime_request_id=runtime_request_id,
            runtime_extra=runtime_extra,
            on_job_finished=_mark_done,
        )

    def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
        """Fail every registered job in ``request_ids``; unknown ids are skipped."""
        if not request_ids:
            return
        bridge = self.scheduler_worker.completion_bridge
        for rid in request_ids:
            pending = self._get_engine_job(rid)
            if pending is not None:
                bridge.fail_job(
                    pending,
                    error=error,
                    # Default argument pins this iteration's id for the callback.
                    on_job_finished=lambda rid=rid: self._record_engine_job_done(rid),
                )

    def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
        """Add ``elapsed_s`` (converted to milliseconds) to each job's ``prefill_ms``."""
        elapsed_ms = float(elapsed_s) * 1000.0
        for pending in jobs:
            pending.prefill_ms += elapsed_ms

    def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Add ``elapsed_s`` (in ms) to ``merge_ms`` of each registered job; skip unknown ids."""
        elapsed_ms = float(elapsed_s) * 1000.0
        for rid in request_ids:
            pending = self._get_engine_job(rid)
            if pending is None:
                continue
            pending.merge_ms += elapsed_ms

    def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Account one decode step of ``elapsed_s`` for each registered job.

        Jobs taking their first decode step (``decode_steps == 0``) have their
        request state switched to ``EngineStatus.ACTIVE_DECODE``, but only after
        all jobs have been updated.
        """
        elapsed_ms = float(elapsed_s) * 1000.0
        first_step_ids: List[str] = []
        for rid in request_ids:
            pending = self._get_engine_job(rid)
            if pending is not None:
                if pending.decode_steps == 0:
                    first_step_ids.append(pending.engine_request_id)
                pending.decode_ms += elapsed_ms
                pending.decode_steps += 1
        for engine_rid in first_step_ids:
            self._update_request_state(engine_rid, EngineStatus.ACTIVE_DECODE, None)

    def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
        """Wrap finished items as finalize tasks and hand them to the stage coordinator.

        All tasks in one call share a single ``time.perf_counter()`` timestamp.
        Does nothing when ``items`` is empty.
        """
        if not items:
            return
        stamp = time.perf_counter()
        finalize_tasks = []
        for finished in items:
            finalize_tasks.append(
                SchedulerFinalizeTask(request_id=finished.request_id, item=finished, enqueued_time=stamp)
            )
        self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(finalize_tasks)
_snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py new file mode 100644 index 00000000..29b5aaab --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineStageBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_stage_coordinator(self): + return self.owner.engine_stage_coordinator + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: 
Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_decode_runtime_owner.enqueue_pending_job(job) + self.owner.engine_policy_arbiter.notify() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.engine_stage_coordinator.has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + 
timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _run_engine_prepare_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + return self.engine_stage_coordinator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py index f68d3ede..d60a3bb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -1,401 +1,9 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple - -import numpy as np - -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState -from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade -from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, 
SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution -from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade - - -class EngineBridgeDelegates: - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.bridge_facade._register_request_state( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._update_request_state(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._merge_request_state_profile(request_id, extra) - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_prepare_state() - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_finalize_state() - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_dispatch_state() - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.bridge_facade._register_engine_job(job) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.bridge_facade._get_engine_job(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.bridge_facade._pop_engine_job(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_job_registry() - - def _is_engine_drained(self) -> bool: - return 
self.bridge_facade._is_engine_drained() - - def _record_engine_job_done(self, request_id: str) -> None: - self.bridge_facade._record_engine_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - self.bridge_facade._fail_engine_jobs(request_ids, error) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - self.bridge_facade._enqueue_engine_finished_items(items) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_decode_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineBridgeFacade._summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.bridge_facade._refresh_engine_decode_runtime_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - self.bridge_facade._update_engine_decode_runtime_state(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_decode_runtime_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - 
return self.bridge_facade._snapshot_engine_arbiter_state() - - def _notify_engine_arbiter(self) -> None: - self.bridge_facade._notify_engine_arbiter() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.bridge_facade._enqueue_engine_decode_pending_job(job) - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - return self.bridge_facade._peek_queue_age_ms(queue_name) - - def _engine_has_pending_work(self) -> bool: - return self.bridge_facade._engine_has_pending_work() - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - return await self.bridge_facade._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=prepare_submit_at, - engine_request_id=engine_request_id, - ) - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - return self.bridge_facade._take_engine_finalize_batch_nonblocking() - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - return await self.bridge_facade._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=speed_factor, - sample_steps=sample_steps, - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - 
prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id, - timeout_sec=timeout_sec, - ) - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - return self.bridge_facade._select_engine_stage() - - def _run_engine_prepare_once(self) -> bool: - return self.bridge_facade._run_engine_prepare_once() - - def _run_engine_finalize_once(self) -> bool: - return self.bridge_facade._run_engine_finalize_once() - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) - - def _run_engine_decode_runtime_once(self) -> bool: - return self.bridge_facade._run_engine_decode_runtime_once() - - def _run_engine_arbiter_loop(self) -> None: - self.bridge_facade._run_engine_arbiter_loop() - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._complete_request_state(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.bridge_facade._fail_request_state(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_request_registry() - - -class EngineApiDelegates: - def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - return self.api_facade._collect_request_summaries(request_ids) - - def _has_active_request(self, request_id: str) -> bool: - return self.api_facade._has_active_request(request_id) - - @staticmethod - def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: - return EngineApiFacade._build_request_meta(payload) - - @staticmethod 
- def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: - return EngineApiFacade._sum_profile_field(items, key) - - def _build_direct_segment_trace( - self, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) - - def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_direct_scheduler_profile(**kwargs) - - def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_legacy_direct_profile(**kwargs) - - def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_scheduler_submit_profile(**kwargs) - - @staticmethod - def _format_ms_header(value: Any) -> str: - return EngineApiFacade._format_ms_header(value) - - def _build_scheduler_submit_headers( - self, - *, - request_id: str, - media_type: str, - sample_rate: int, - profile: Dict[str, Any], - ) -> Dict[str, str]: - return self.api_facade._build_scheduler_submit_headers( - request_id=request_id, - media_type=media_type, - sample_rate=sample_rate, - profile=profile, - ) - - def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_scheduler_debug_request_profile(**kwargs) - - @staticmethod - def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: - return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) - - def _normalize_lang(self, value: str | None) -> str | None: - return self.api_facade._normalize_lang(value) - - @staticmethod - def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: - return EngineApiFacade._aggregate_numeric_dicts(items) - - def _apply_default_reference(self, req: dict) -> dict: - return 
self.api_facade._apply_default_reference(req) - - def check_params(self, req: dict) -> Optional[str]: - return self.api_facade.check_params(req) - - @staticmethod - def _base_request_defaults() -> Dict[str, Any]: - return EngineApiFacade._base_request_defaults() - - def _normalize_engine_request( - self, - payload: dict | NormalizedEngineRequest, - *, - request_id: str | None = None, - normalize_streaming: bool = False, - error_prefix: str = "request 参数非法: ", - ) -> NormalizedEngineRequest: - return self.api_facade._normalize_engine_request( - payload, - request_id=request_id, - normalize_streaming=normalize_streaming, - error_prefix=error_prefix, - ) - - @staticmethod - def _normalize_streaming_mode(req: dict) -> dict: - return EngineApiFacade._normalize_streaming_mode(req) - - @staticmethod - def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: - return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) - - def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - return self.api_facade._select_direct_backend(normalized) - - def _iter_legacy_direct_tts_bytes( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> Generator[bytes, None, None]: - return self.api_facade._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: - return self.api_facade._should_use_scheduler_backend_for_direct(req) - - def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: - return self.api_facade._segment_direct_text(normalized) - - def _build_segment_request( - self, - normalized: NormalizedEngineRequest, - *, - request_id: str, - text: str, - ) -> NormalizedEngineRequest: - return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) - - async def 
_run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: - return await self.api_facade._run_direct_tts_via_scheduler(normalized) - - def _run_legacy_direct_tts_blocking( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - return self.api_facade._run_legacy_direct_tts_blocking( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - async def _run_direct_tts_via_legacy_backend( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - return await self.api_facade._run_direct_tts_via_legacy_backend( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - -class EngineRuntimeDelegates: - @staticmethod - def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: - return EngineRuntimeFacade._safe_component_snapshot(component) - - def _build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_stage_counters(request_registry, worker_state) - - def _build_engine_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) - - async def _wait_for_engine_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - return await self.engine_policy_arbiter.wait_for_policy_admission( - request_id=request_id, - timeout_sec=timeout_sec, - ) - - def _build_stage_summary( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_stage_summary(request_registry, worker_state) - - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: - 
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates + +__all__ = [ + "EngineApiDelegates", + "EngineBridgeDelegates", + "EngineRuntimeDelegates", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py new file mode 100644 index 00000000..a71f7e4e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageOrchestrator: + def __init__( + self, + *, + executor: EngineStageExecutor, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.executor = executor + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None + self._mark_arbiter_tick: Callable[[str, str, bool], 
None] | None = None + self._wait_arbiter: Callable[[], None] | None = None + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self.executor.bind_notify_arbiter(notify_arbiter) + self._select_stage = select_stage + self._mark_arbiter_tick = mark_arbiter_tick + self._wait_arbiter = wait_arbiter + + def peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "finalize": + return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + if queue_name == "decode_runtime_pending": + return self.decode_runtime_owner.pending_age_ms() + return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( + "active_request_count", 0 + ) > 0: + return True + if self.prepare_queue_owner.has_items(): + return True + if self.finalize_queue_owner.has_items(): + return True + return self.dispatch_queue_owner.has_items() + + def run_engine_arbiter_loop(self) -> None: + if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: + raise RuntimeError("arbiter callbacks are not bound") + while True: + if not self.has_pending_work(): + self._mark_arbiter_tick("idle", "no_pending_work", True) + self._wait_arbiter() + continue + stage, reason, policy_snapshot, worker_state = self._select_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self.executor.run_engine_prepare_once() + 
elif stage == "finalize": + executed = self.executor.run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self.executor.run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) + self._wait_arbiter() + continue + self._mark_arbiter_tick(stage, reason, policy_allowed) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py new file mode 100644 index 00000000..96153196 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade + + +class EngineRuntimeDelegates: + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + return EngineRuntimeFacade._safe_component_snapshot(component) + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return 
self.runtime_facade._build_stage_summary(request_registry, worker_state) + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 65f0befe..9aad2fb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -1,20 +1,19 @@ from __future__ import annotations import asyncio -import time -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( EngineDecodeRuntimeOwner, EngineDispatchTask, - EngineGpuPrepareTask, - EngineStatus, EngineTaskQueueOwner, SchedulerFinalizeTask, SchedulerPendingJob, ) +from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker @@ -42,29 +41,36 @@ class EngineStageCoordinator: snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], ) -> None: - self.tts = tts - self.scheduler_worker = scheduler_worker - self.prepare_queue_owner = prepare_queue_owner - self.finalize_queue_owner = finalize_queue_owner - self.dispatch_queue_owner = dispatch_queue_owner - self.decode_runtime_owner = decode_runtime_owner - self.update_request_state = update_request_state - self.merge_request_state_profile = merge_request_state_profile - self.fail_request_state = fail_request_state - self.get_engine_job = get_engine_job - self.register_engine_job = register_engine_job - 
self.fail_engine_jobs = fail_engine_jobs - self.complete_engine_job = complete_engine_job - self.add_engine_prefill_time = add_engine_prefill_time - self.add_engine_merge_time = add_engine_merge_time - self.add_engine_decode_time = add_engine_decode_time - self.enqueue_engine_finished_items = enqueue_engine_finished_items - self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state - self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state - self._notify_arbiter: Callable[[], None] | None = None - self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None - self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None - self._wait_arbiter: Callable[[], None] | None = None + self.executor = EngineStageExecutor( + tts=tts, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + update_request_state=update_request_state, + merge_request_state_profile=merge_request_state_profile, + fail_request_state=fail_request_state, + get_engine_job=get_engine_job, + register_engine_job=register_engine_job, + fail_engine_jobs=fail_engine_jobs, + complete_engine_job=complete_engine_job, + add_engine_prefill_time=add_engine_prefill_time, + add_engine_merge_time=add_engine_merge_time, + add_engine_decode_time=add_engine_decode_time, + enqueue_engine_finished_items=enqueue_engine_finished_items, + snapshot_engine_dispatch_state=snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) + self.orchestrator = EngineStageOrchestrator( + executor=self.executor, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + 
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) def bind_arbiter( self, @@ -74,57 +80,12 @@ class EngineStageCoordinator: mark_arbiter_tick: Callable[[str, str, bool], None], wait_arbiter: Callable[[], None], ) -> None: - self._notify_arbiter = notify_arbiter - self._select_stage = select_stage - self._mark_arbiter_tick = mark_arbiter_tick - self._wait_arbiter = wait_arbiter - - def notify_arbiter(self) -> None: - if self._notify_arbiter is not None: - self._notify_arbiter() - - @staticmethod - def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: - if future.done(): - return - future.set_exception(error) - - def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - @staticmethod - def _resolve_prepare_future( - future: asyncio.Future, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if future.done(): - return - future.set_result(payload) - - def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - def _notify_prepare_result( - self, - task: EngineGpuPrepareTask, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) - except RuntimeError: - pass + self.orchestrator.bind_arbiter( + notify_arbiter=notify_arbiter, + select_stage=select_stage, + mark_arbiter_tick=mark_arbiter_tick, + wait_arbiter=wait_arbiter, + ) async def 
prepare_state_via_engine_gpu_queue( self, @@ -133,59 +94,17 @@ class EngineStageCoordinator: prepare_submit_at: float, engine_request_id: str | None, ) -> tuple[T2SRequestState, float, float]: - cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - if engine_request_id not in [None, ""]: - self.update_request_state( - str(engine_request_id), - EngineStatus.GPU_PREPARING, - { - "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), - "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), - "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), - "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), - }, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - task = EngineGpuPrepareTask( - request_id=spec.request_id, - cpu_stage=cpu_stage, - done_loop=loop, - done_future=done_future, - engine_request_id=engine_request_id or spec.request_id, - enqueue_time=time.perf_counter(), + return await self.executor.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, ) - self.prepare_queue_owner.enqueue(task) - self.notify_arbiter() - state, prepare_exec_started_at, prepare_exec_finished_at = await done_future - return state, prepare_exec_started_at, prepare_exec_finished_at def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - self.update_request_state( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": task.item.finish_reason, - "semantic_len": int(task.item.semantic_tokens.shape[0]), - "finish_idx": int(task.item.finish_idx), - }, - ) - self.finalize_queue_owner.enqueue_many(tasks) - self.notify_arbiter() + self.executor.enqueue_worker_finished_for_finalize(tasks) def 
take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - finalize_policy = self.scheduler_worker.get_finalize_batch_policy() - return self.finalize_queue_owner.take_finalize_batch( - finalize_mode=str(finalize_policy.get("finalize_mode", "async")), - batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), - batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), - use_vocoder=bool(self.tts.configs.use_vocoder), - ) + return self.executor.take_engine_finalize_batch_nonblocking() async def enqueue_prepared_state_for_dispatch( self, @@ -201,220 +120,36 @@ class EngineStageCoordinator: engine_request_id: str | None, timeout_sec: float | None, ) -> EngineDispatchTask: - task = EngineDispatchTask( - request_id=state.request_id, + return await self.executor.enqueue_prepared_state_for_dispatch( state=state, - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), + speed_factor=speed_factor, + sample_steps=sample_steps, media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, done_future=done_future, - engine_request_id=engine_request_id or state.request_id, + engine_request_id=engine_request_id, timeout_sec=timeout_sec, - enqueue_time=time.perf_counter(), ) - self.dispatch_queue_owner.enqueue(task) - self.notify_arbiter() - self.merge_request_state_profile( - task.engine_request_id or task.request_id, - { - "engine_dispatch_queue_depth_on_enqueue": int( - self.snapshot_engine_dispatch_state()["waiting_count"] - ), - }, - ) - return task def peek_queue_age_ms(self, queue_name: str) -> float: - if queue_name == "prepare": - return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") - if queue_name == "finalize": - return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") - if queue_name == "decode_runtime_pending": 
- return self.decode_runtime_owner.pending_age_ms() - return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + return self.orchestrator.peek_queue_age_ms(queue_name) def has_pending_work(self) -> bool: - if self.scheduler_worker.is_engine_decode_control_enabled(): - if self.decode_runtime_owner.has_pending_jobs(): - return True - if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( - "active_request_count", 0 - ) > 0: - return True - if self.prepare_queue_owner.has_items(): - return True - if self.finalize_queue_owner.has_items(): - return True - return self.dispatch_queue_owner.has_items() + return self.orchestrator.has_pending_work() def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: - return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) - if task.engine_request_id not in [None, ""]: - self.merge_request_state_profile( - str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, - ) - self.prepare_queue_owner.mark_completed(1) - self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True + return self.executor.run_engine_prepare_once() def run_engine_finalize_once(self) -> bool: - tasks = self.take_engine_finalize_batch_nonblocking() - if not tasks: - return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] 
- for task in tasks: - job = self.get_engine_job(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return False - now = time.perf_counter() - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) - for job, item in jobs_and_items: - self.update_request_state( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) - for job, _ in jobs_and_items: - job.synth_ms += float(synth_ms) - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) - finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) - return True + return self.executor.run_engine_finalize_once() def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - if not bool(policy_snapshot.get("allowed", True)): - return False - dispatch_task = self.dispatch_queue_owner.pop_left() - if dispatch_task is None: - return False - dispatched_at = time.perf_counter() - dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) - dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_policy_snapshot = dict(policy_snapshot) - try: - worker_job = self.scheduler_worker.submit( - state=dispatch_task.state, - speed_factor=dispatch_task.speed_factor, - sample_steps=dispatch_task.sample_steps, - 
media_type=dispatch_task.media_type, - prepare_wall_ms=dispatch_task.prepare_wall_ms, - prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, - done_loop=dispatch_task.done_loop, - done_future=dispatch_task.done_future, - engine_request_id=dispatch_task.engine_request_id, - timeout_sec=dispatch_task.timeout_sec, - skip_capacity_wait=True, - admission_wait_ms_override=0.0, - admission_snapshot_override=dict(worker_state), - engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, - engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, - enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), - ) - dispatch_task.worker_job = worker_job - self.register_engine_job(worker_job) - if self.scheduler_worker.is_engine_decode_control_enabled(): - self.decode_runtime_owner.enqueue_pending_job(worker_job) - self.notify_arbiter() - self.dispatch_queue_owner.mark_completed(1) - return True - except Exception as exc: - dispatch_task.error = str(exc) - self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) - self._notify_dispatch_error(dispatch_task, exc) - return True + return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) def run_engine_decode_runtime_once(self) -> bool: - if not self.scheduler_worker.is_engine_decode_control_enabled(): - return False - runtime_state = self.snapshot_engine_decode_runtime_state() - pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( - wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 - ) - result = self.scheduler_worker.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=self.decode_runtime_owner.get_active_batch(), - external_bookkeeping=True, - ) - prefill_phase = dict(result.get("prefill_phase") or {}) - if prefill_phase.get("error"): - self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) - else: - prefill_jobs = 
list(prefill_phase.get("pending_jobs") or []) - self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) - self.add_engine_merge_time( - [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), - float(prefill_phase.get("merge_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) - decode_phase = dict(result.get("decode_phase") or {}) - if decode_phase.get("error"): - self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) - else: - self.add_engine_decode_time( - list(decode_phase.get("request_ids") or []), - float(decode_phase.get("decode_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) - self.decode_runtime_owner.set_active_batch(result.get("active_batch")) - if result.get("executed", False): - self.decode_runtime_owner.refresh_state("engine_decode_cycle") - return bool(result.get("executed", False)) + return self.executor.run_engine_decode_runtime_once() def run_engine_arbiter_loop(self) -> None: - if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: - raise RuntimeError("arbiter callbacks are not bound") - while True: - if not self.has_pending_work(): - self._mark_arbiter_tick("idle", "no_pending_work", True) - self._wait_arbiter() - continue - stage, reason, policy_snapshot, worker_state = self._select_stage() - policy_allowed = bool(policy_snapshot.get("allowed", True)) - executed = False - if stage == "prepare": - executed = self.run_engine_prepare_once() - elif stage == "finalize": - executed = self.run_engine_finalize_once() - elif stage == "decode_dispatch": - executed = self.run_engine_dispatch_once(policy_snapshot, worker_state) - elif stage == "decode_runtime": - executed = self.run_engine_decode_runtime_once() - if not executed: - self._mark_arbiter_tick("idle", 
f"{stage}_not_ready", policy_allowed) - self._wait_arbiter() - continue - self._mark_arbiter_tick(stage, reason, policy_allowed) + self.orchestrator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py new file mode 100644 index 00000000..77274056 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineDispatchTask, + EngineGpuPrepareTask, + EngineStatus, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageExecutor: + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: 
Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.tts = tts + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.update_request_state = update_request_state + self.merge_request_state_profile = merge_request_state_profile + self.fail_request_state = fail_request_state + self.get_engine_job = get_engine_job + self.register_engine_job = register_engine_job + self.fail_engine_jobs = fail_engine_jobs + self.complete_engine_job = complete_engine_job + self.add_engine_prefill_time = add_engine_prefill_time + self.add_engine_merge_time = add_engine_merge_time + self.add_engine_decode_time = add_engine_decode_time + self.enqueue_engine_finished_items = enqueue_engine_finished_items + self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._notify_arbiter: Callable[[], None] | None = None + + def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None: + self._notify_arbiter = notify_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + 
payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec: Any, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + return await done_future + + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = 
self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def 
run_engine_prepare_once(self) -> bool: + task = self.prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), 
(sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + 
self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + 
self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False))