Add unified engine delegates and orchestration components for enhanced TTS processing

Introduce new modules including EngineApiDelegates, EngineBridgeDelegates, EngineRegistryBridgeFacade, EngineRuntimeBridgeFacade, EngineStageBridgeFacade, and EngineStageOrchestrator. These additions provide a structured approach to managing TTS requests, engine states, and orchestration, significantly improving the architecture and maintainability of the TTS system. The new components support asynchronous operations and enhance overall performance through better request handling and processing capabilities.
This commit is contained in:
baicai-1145 2026-03-11 18:35:47 +08:00
parent 800f01790e
commit b046a093d3
11 changed files with 1280 additions and 1025 deletions

View File

@ -0,0 +1,165 @@
from __future__ import annotations
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, NormalizedEngineRequest
class EngineApiDelegates:
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
return self.api_facade._collect_request_summaries(request_ids)
def _has_active_request(self, request_id: str) -> bool:
return self.api_facade._has_active_request(request_id)
@staticmethod
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
return EngineApiFacade._build_request_meta(payload)
@staticmethod
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
return EngineApiFacade._sum_profile_field(items, key)
def _build_direct_segment_trace(
self,
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_direct_scheduler_profile(**kwargs)
def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_legacy_direct_profile(**kwargs)
def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_scheduler_submit_profile(**kwargs)
@staticmethod
def _format_ms_header(value: Any) -> str:
return EngineApiFacade._format_ms_header(value)
def _build_scheduler_submit_headers(
self,
*,
request_id: str,
media_type: str,
sample_rate: int,
profile: Dict[str, Any],
) -> Dict[str, str]:
return self.api_facade._build_scheduler_submit_headers(
request_id=request_id,
media_type=media_type,
sample_rate=sample_rate,
profile=profile,
)
def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_scheduler_debug_request_profile(**kwargs)
@staticmethod
def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs)
def _normalize_lang(self, value: str | None) -> str | None:
return self.api_facade._normalize_lang(value)
@staticmethod
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
return EngineApiFacade._aggregate_numeric_dicts(items)
def _apply_default_reference(self, req: dict) -> dict:
return self.api_facade._apply_default_reference(req)
def check_params(self, req: dict) -> Optional[str]:
return self.api_facade.check_params(req)
@staticmethod
def _base_request_defaults() -> Dict[str, Any]:
return EngineApiFacade._base_request_defaults()
def _normalize_engine_request(
self,
payload: dict | NormalizedEngineRequest,
*,
request_id: str | None = None,
normalize_streaming: bool = False,
error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
return self.api_facade._normalize_engine_request(
payload,
request_id=request_id,
normalize_streaming=normalize_streaming,
error_prefix=error_prefix,
)
@staticmethod
def _normalize_streaming_mode(req: dict) -> dict:
return EngineApiFacade._normalize_streaming_mode(req)
@staticmethod
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths)
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
return self.api_facade._select_direct_backend(normalized)
def _iter_legacy_direct_tts_bytes(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> Generator[bytes, None, None]:
return self.api_facade._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
return self.api_facade._should_use_scheduler_backend_for_direct(req)
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
return self.api_facade._segment_direct_text(normalized)
def _build_segment_request(
self,
normalized: NormalizedEngineRequest,
*,
request_id: str,
text: str,
) -> NormalizedEngineRequest:
return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text)
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_scheduler(normalized)
def _run_legacy_direct_tts_blocking(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return self.api_facade._run_legacy_direct_tts_blocking(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def _run_direct_tts_via_legacy_backend(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_legacy_backend(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)

View File

@ -1,310 +1,21 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_registry import EngineRegistryBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_runtime import EngineRuntimeBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_stage import EngineStageBridgeFacade
class EngineBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
self.registry_bridge = EngineRegistryBridgeFacade(owner)
self.stage_bridge = EngineStageBridgeFacade(owner)
self.runtime_bridge = EngineRuntimeBridgeFacade(owner)
@property
def request_registry(self):
return self.owner.request_registry
@property
def engine_prepare_queue_owner(self):
return self.owner.engine_prepare_queue_owner
@property
def engine_finalize_queue_owner(self):
return self.owner.engine_finalize_queue_owner
@property
def engine_dispatch_queue_owner(self):
return self.owner.engine_dispatch_queue_owner
@property
def engine_decode_runtime_owner(self):
return self.owner.engine_decode_runtime_owner
@property
def engine_job_registry(self):
return self.owner.engine_job_registry
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
@property
def engine_stage_coordinator(self):
return self.owner.engine_stage_coordinator
@property
def engine_policy_arbiter(self):
return self.owner.engine_policy_arbiter
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
return self.request_registry.register(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(
self,
request_id: str,
status: str,
extra: Optional[Dict[str, Any]] = None,
) -> None:
self.request_registry.update(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.request_registry.merge_profile(request_id, extra)
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.request_registry.complete(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.request_registry.fail(request_id, error)
def _snapshot_request_registry(self) -> Dict[str, Any]:
return self.request_registry.snapshot()
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
return self.engine_dispatch_queue_owner.snapshot(
max_request_ids=16,
extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})},
)
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
self.engine_job_registry.register(job, keep_job=True)
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.engine_job_registry.get(request_id)
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.engine_job_registry.pop(request_id)
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
return self.engine_job_registry.snapshot(max_request_ids=32)
def _is_engine_drained(self) -> bool:
prepare_empty = self.engine_prepare_queue_owner.is_drained()
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
finalize_empty = self.engine_finalize_queue_owner.is_drained()
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
job_empty = self.engine_job_registry.is_empty()
worker_state = self.scheduler_worker.snapshot()
return bool(
prepare_empty
and dispatch_empty
and finalize_empty
and decode_pending_empty
and job_empty
and self.engine_decode_runtime_owner.get_active_batch() is None
and int(worker_state.get("prepare_inflight", 0)) <= 0
and int(worker_state.get("finalize_inflight", 0)) <= 0
and int(worker_state.get("finalize_pending", 0)) <= 0
)
def _record_engine_job_done(self, request_id: str) -> None:
self.engine_job_registry.mark_finished_and_remove(request_id)
self.scheduler_worker.record_external_job_done(request_id)
def _complete_engine_job(
self,
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
completion_bridge = self.scheduler_worker.completion_bridge
completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
completion_bridge.complete_job(
job,
runtime_request_id=job.engine_request_id,
runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate),
on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid),
)
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
if not request_ids:
return
completion_bridge = self.scheduler_worker.completion_bridge
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is None:
continue
completion_bridge.fail_job(
job,
error=error,
on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid),
)
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
for job in jobs:
job.prefill_ms += delta_ms
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is not None:
job.merge_ms += delta_ms
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
activate_request_ids: List[str] = []
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is None:
continue
if job.decode_steps == 0:
activate_request_ids.append(job.engine_request_id)
job.decode_ms += delta_ms
job.decode_steps += 1
for engine_request_id in activate_request_ids:
self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
if not items:
return
enqueued_at = time.perf_counter()
tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items]
self._enqueue_worker_finished_for_finalize(tasks)
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
return self.engine_decode_runtime_owner.snapshot_pending_queue_state()
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
self.engine_decode_runtime_owner.refresh_state(last_event)
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
if not snapshot:
return
if self.scheduler_worker.is_engine_decode_control_enabled():
return
self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot)
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
return self.engine_decode_runtime_owner.snapshot_state()
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
return self.engine_policy_arbiter.snapshot_state()
def _notify_engine_arbiter(self) -> None:
self.engine_policy_arbiter.notify()
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
self.engine_stage_coordinator.decode_runtime_owner.enqueue_pending_job(job)
self._notify_engine_arbiter()
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.engine_stage_coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch)
def _peek_queue_age_ms(self, queue_name: str) -> float:
return self.engine_stage_coordinator.peek_queue_age_ms(queue_name)
def _engine_has_pending_work(self) -> bool:
return self.engine_stage_coordinator.has_pending_work()
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking()
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage()
self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot)
return stage, reason, policy_snapshot, worker_state
def _run_engine_prepare_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_prepare_once()
def _run_engine_finalize_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_finalize_once()
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state)
def _run_engine_decode_runtime_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_decode_runtime_once()
def _run_engine_arbiter_loop(self) -> None:
self.engine_stage_coordinator.run_engine_arbiter_loop()
def __getattr__(self, name: str) -> Any:
for component in (self.registry_bridge, self.stage_bridge, self.runtime_bridge):
if hasattr(component, name):
return getattr(component, name)
raise AttributeError(name)

View File

@ -0,0 +1,200 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineRequestState, SchedulerFinalizeTask, SchedulerPendingJob
class EngineBridgeDelegates:
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
return self.bridge_facade._register_request_state(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._update_request_state(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._merge_request_state_profile(request_id, extra)
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_prepare_state()
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_finalize_state()
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_dispatch_state()
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
self.bridge_facade._register_engine_job(job)
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.bridge_facade._get_engine_job(request_id)
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.bridge_facade._pop_engine_job(request_id)
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_job_registry()
def _is_engine_drained(self) -> bool:
return self.bridge_facade._is_engine_drained()
def _record_engine_job_done(self, request_id: str) -> None:
self.bridge_facade._record_engine_job_done(request_id)
def _complete_engine_job(
self,
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
self.bridge_facade._fail_engine_jobs(request_ids, error)
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s)
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s)
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s)
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
self.bridge_facade._enqueue_engine_finished_items(items)
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_decode_pending_queue_state()
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
return EngineBridgeFacade._summarize_active_batch(active_batch)
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
self.bridge_facade._refresh_engine_decode_runtime_state(last_event)
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
self.bridge_facade._update_engine_decode_runtime_state(snapshot)
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_decode_runtime_state()
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_arbiter_state()
def _notify_engine_arbiter(self) -> None:
self.bridge_facade._notify_engine_arbiter()
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
self.bridge_facade._enqueue_engine_decode_pending_job(job)
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)
def _peek_queue_age_ms(self, queue_name: str) -> float:
return self.bridge_facade._peek_queue_age_ms(queue_name)
def _engine_has_pending_work(self) -> bool:
return self.bridge_facade._engine_has_pending_work()
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.bridge_facade._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.bridge_facade._enqueue_worker_finished_for_finalize(tasks)
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.bridge_facade._take_engine_finalize_batch_nonblocking()
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.bridge_facade._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
return self.bridge_facade._select_engine_stage()
def _run_engine_prepare_once(self) -> bool:
return self.bridge_facade._run_engine_prepare_once()
def _run_engine_finalize_once(self) -> bool:
return self.bridge_facade._run_engine_finalize_once()
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state)
def _run_engine_decode_runtime_once(self) -> bool:
return self.bridge_facade._run_engine_decode_runtime_once()
def _run_engine_arbiter_loop(self) -> None:
self.bridge_facade._run_engine_arbiter_loop()
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._complete_request_state(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.bridge_facade._fail_request_state(request_id, error)
def _snapshot_request_registry(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_request_registry()

View File

@ -0,0 +1,193 @@
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class EngineRegistryBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def request_registry(self):
return self.owner.request_registry
@property
def engine_prepare_queue_owner(self):
return self.owner.engine_prepare_queue_owner
@property
def engine_finalize_queue_owner(self):
return self.owner.engine_finalize_queue_owner
@property
def engine_dispatch_queue_owner(self):
return self.owner.engine_dispatch_queue_owner
@property
def engine_decode_runtime_owner(self):
return self.owner.engine_decode_runtime_owner
@property
def engine_job_registry(self):
return self.owner.engine_job_registry
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
return self.request_registry.register(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(
self,
request_id: str,
status: str,
extra: Optional[Dict[str, Any]] = None,
) -> None:
self.request_registry.update(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.request_registry.merge_profile(request_id, extra)
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.request_registry.complete(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.request_registry.fail(request_id, error)
def _snapshot_request_registry(self) -> Dict[str, Any]:
return self.request_registry.snapshot()
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
return self.engine_dispatch_queue_owner.snapshot(
max_request_ids=16,
extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})},
)
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
self.engine_job_registry.register(job, keep_job=True)
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.engine_job_registry.get(request_id)
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.engine_job_registry.pop(request_id)
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
return self.engine_job_registry.snapshot(max_request_ids=32)
def _is_engine_drained(self) -> bool:
prepare_empty = self.engine_prepare_queue_owner.is_drained()
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
finalize_empty = self.engine_finalize_queue_owner.is_drained()
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
job_empty = self.engine_job_registry.is_empty()
worker_state = self.scheduler_worker.snapshot()
return bool(
prepare_empty
and dispatch_empty
and finalize_empty
and decode_pending_empty
and job_empty
and self.engine_decode_runtime_owner.get_active_batch() is None
and int(worker_state.get("prepare_inflight", 0)) <= 0
and int(worker_state.get("finalize_inflight", 0)) <= 0
and int(worker_state.get("finalize_pending", 0)) <= 0
)
def _record_engine_job_done(self, request_id: str) -> None:
self.engine_job_registry.mark_finished_and_remove(request_id)
self.scheduler_worker.record_external_job_done(request_id)
def _complete_engine_job(
self,
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
completion_bridge = self.scheduler_worker.completion_bridge
completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
completion_bridge.complete_job(
job,
runtime_request_id=job.engine_request_id,
runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate),
on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid),
)
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
if not request_ids:
return
completion_bridge = self.scheduler_worker.completion_bridge
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is None:
continue
completion_bridge.fail_job(
job,
error=error,
on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid),
)
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
for job in jobs:
job.prefill_ms += delta_ms
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is not None:
job.merge_ms += delta_ms
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
activate_request_ids: List[str] = []
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is None:
continue
if job.decode_steps == 0:
activate_request_ids.append(job.engine_request_id)
job.decode_ms += delta_ms
job.decode_steps += 1
for engine_request_id in activate_request_ids:
self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
if not items:
return
enqueued_at = time.perf_counter()
tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items]
self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)

View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import Any, Dict
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner
class EngineRuntimeBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def engine_policy_arbiter(self):
return self.owner.engine_policy_arbiter
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
return self.engine_policy_arbiter.snapshot_state()
def _notify_engine_arbiter(self) -> None:
self.engine_policy_arbiter.notify()
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage()
self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot)
return stage, reason, policy_snapshot, worker_state

View File

@ -0,0 +1,114 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob
class EngineStageBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def engine_decode_runtime_owner(self):
return self.owner.engine_decode_runtime_owner
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
@property
def engine_stage_coordinator(self):
return self.owner.engine_stage_coordinator
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
return self.engine_decode_runtime_owner.snapshot_pending_queue_state()
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
self.engine_decode_runtime_owner.refresh_state(last_event)
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
if not snapshot:
return
if self.scheduler_worker.is_engine_decode_control_enabled():
return
self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot)
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
return self.engine_decode_runtime_owner.snapshot_state()
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
self.engine_decode_runtime_owner.enqueue_pending_job(job)
self.owner.engine_policy_arbiter.notify()
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch)
def _peek_queue_age_ms(self, queue_name: str) -> float:
return self.engine_stage_coordinator.peek_queue_age_ms(queue_name)
def _engine_has_pending_work(self) -> bool:
return self.engine_stage_coordinator.has_pending_work()
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking()
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _run_engine_prepare_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_prepare_once()
def _run_engine_finalize_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_finalize_once()
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state)
def _run_engine_decode_runtime_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_decode_runtime_once()
def _run_engine_arbiter_loop(self) -> None:
return self.engine_stage_coordinator.run_engine_arbiter_loop()

View File

@ -1,401 +1,9 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
class EngineBridgeDelegates:
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
return self.bridge_facade._register_request_state(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._update_request_state(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._merge_request_state_profile(request_id, extra)
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_prepare_state()
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_finalize_state()
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_dispatch_state()
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
self.bridge_facade._register_engine_job(job)
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.bridge_facade._get_engine_job(request_id)
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.bridge_facade._pop_engine_job(request_id)
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_job_registry()
def _is_engine_drained(self) -> bool:
return self.bridge_facade._is_engine_drained()
def _record_engine_job_done(self, request_id: str) -> None:
self.bridge_facade._record_engine_job_done(request_id)
def _complete_engine_job(
self,
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
self.bridge_facade._fail_engine_jobs(request_ids, error)
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s)
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s)
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s)
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
self.bridge_facade._enqueue_engine_finished_items(items)
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_decode_pending_queue_state()
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
return EngineBridgeFacade._summarize_active_batch(active_batch)
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
self.bridge_facade._refresh_engine_decode_runtime_state(last_event)
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
self.bridge_facade._update_engine_decode_runtime_state(snapshot)
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_decode_runtime_state()
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_arbiter_state()
def _notify_engine_arbiter(self) -> None:
self.bridge_facade._notify_engine_arbiter()
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
self.bridge_facade._enqueue_engine_decode_pending_job(job)
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)
def _peek_queue_age_ms(self, queue_name: str) -> float:
return self.bridge_facade._peek_queue_age_ms(queue_name)
def _engine_has_pending_work(self) -> bool:
return self.bridge_facade._engine_has_pending_work()
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.bridge_facade._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.bridge_facade._enqueue_worker_finished_for_finalize(tasks)
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.bridge_facade._take_engine_finalize_batch_nonblocking()
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.bridge_facade._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
return self.bridge_facade._select_engine_stage()
def _run_engine_prepare_once(self) -> bool:
return self.bridge_facade._run_engine_prepare_once()
def _run_engine_finalize_once(self) -> bool:
return self.bridge_facade._run_engine_finalize_once()
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state)
def _run_engine_decode_runtime_once(self) -> bool:
return self.bridge_facade._run_engine_decode_runtime_once()
def _run_engine_arbiter_loop(self) -> None:
self.bridge_facade._run_engine_arbiter_loop()
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._complete_request_state(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.bridge_facade._fail_request_state(request_id, error)
def _snapshot_request_registry(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_request_registry()
class EngineApiDelegates:
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
return self.api_facade._collect_request_summaries(request_ids)
def _has_active_request(self, request_id: str) -> bool:
return self.api_facade._has_active_request(request_id)
@staticmethod
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
return EngineApiFacade._build_request_meta(payload)
@staticmethod
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
return EngineApiFacade._sum_profile_field(items, key)
def _build_direct_segment_trace(
self,
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_direct_scheduler_profile(**kwargs)
def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_legacy_direct_profile(**kwargs)
def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_scheduler_submit_profile(**kwargs)
@staticmethod
def _format_ms_header(value: Any) -> str:
return EngineApiFacade._format_ms_header(value)
def _build_scheduler_submit_headers(
self,
*,
request_id: str,
media_type: str,
sample_rate: int,
profile: Dict[str, Any],
) -> Dict[str, str]:
return self.api_facade._build_scheduler_submit_headers(
request_id=request_id,
media_type=media_type,
sample_rate=sample_rate,
profile=profile,
)
def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_scheduler_debug_request_profile(**kwargs)
@staticmethod
def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs)
def _normalize_lang(self, value: str | None) -> str | None:
return self.api_facade._normalize_lang(value)
@staticmethod
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
return EngineApiFacade._aggregate_numeric_dicts(items)
def _apply_default_reference(self, req: dict) -> dict:
return self.api_facade._apply_default_reference(req)
def check_params(self, req: dict) -> Optional[str]:
return self.api_facade.check_params(req)
@staticmethod
def _base_request_defaults() -> Dict[str, Any]:
return EngineApiFacade._base_request_defaults()
def _normalize_engine_request(
self,
payload: dict | NormalizedEngineRequest,
*,
request_id: str | None = None,
normalize_streaming: bool = False,
error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
return self.api_facade._normalize_engine_request(
payload,
request_id=request_id,
normalize_streaming=normalize_streaming,
error_prefix=error_prefix,
)
@staticmethod
def _normalize_streaming_mode(req: dict) -> dict:
return EngineApiFacade._normalize_streaming_mode(req)
@staticmethod
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths)
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
return self.api_facade._select_direct_backend(normalized)
def _iter_legacy_direct_tts_bytes(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> Generator[bytes, None, None]:
return self.api_facade._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
return self.api_facade._should_use_scheduler_backend_for_direct(req)
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
return self.api_facade._segment_direct_text(normalized)
def _build_segment_request(
self,
normalized: NormalizedEngineRequest,
*,
request_id: str,
text: str,
) -> NormalizedEngineRequest:
return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text)
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_scheduler(normalized)
def _run_legacy_direct_tts_blocking(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return self.api_facade._run_legacy_direct_tts_blocking(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def _run_direct_tts_via_legacy_backend(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_legacy_backend(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
class EngineRuntimeDelegates:
@staticmethod
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
return EngineRuntimeFacade._safe_component_snapshot(component)
def _build_stage_counters(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_counters(request_registry, worker_state)
def _build_engine_policy_snapshot(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state)
async def _wait_for_engine_policy_admission(
self,
*,
request_id: str | None,
timeout_sec: float | None,
) -> tuple[float, Dict[str, Any]]:
return await self.engine_policy_arbiter.wait_for_policy_admission(
request_id=request_id,
timeout_sec=timeout_sec,
)
def _build_stage_summary(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates
__all__ = [
"EngineApiDelegates",
"EngineBridgeDelegates",
"EngineRuntimeDelegates",
]

View File

@ -0,0 +1,92 @@
from __future__ import annotations
from typing import Any, Callable, Dict
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineStageOrchestrator:
def __init__(
self,
*,
executor: EngineStageExecutor,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
) -> None:
self.executor = executor
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None
self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None
self._wait_arbiter: Callable[[], None] | None = None
def bind_arbiter(
self,
*,
notify_arbiter: Callable[[], None],
select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
mark_arbiter_tick: Callable[[str, str, bool], None],
wait_arbiter: Callable[[], None],
) -> None:
self.executor.bind_notify_arbiter(notify_arbiter)
self._select_stage = select_stage
self._mark_arbiter_tick = mark_arbiter_tick
self._wait_arbiter = wait_arbiter
def peek_queue_age_ms(self, queue_name: str) -> float:
if queue_name == "prepare":
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "finalize":
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
if queue_name == "decode_runtime_pending":
return self.decode_runtime_owner.pending_age_ms()
return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time")
def has_pending_work(self) -> bool:
if self.scheduler_worker.is_engine_decode_control_enabled():
if self.decode_runtime_owner.has_pending_jobs():
return True
if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get(
"active_request_count", 0
) > 0:
return True
if self.prepare_queue_owner.has_items():
return True
if self.finalize_queue_owner.has_items():
return True
return self.dispatch_queue_owner.has_items()
def run_engine_arbiter_loop(self) -> None:
if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None:
raise RuntimeError("arbiter callbacks are not bound")
while True:
if not self.has_pending_work():
self._mark_arbiter_tick("idle", "no_pending_work", True)
self._wait_arbiter()
continue
stage, reason, policy_snapshot, worker_state = self._select_stage()
policy_allowed = bool(policy_snapshot.get("allowed", True))
executed = False
if stage == "prepare":
executed = self.executor.run_engine_prepare_once()
elif stage == "finalize":
executed = self.executor.run_engine_finalize_once()
elif stage == "decode_dispatch":
executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state)
elif stage == "decode_runtime":
executed = self.executor.run_engine_decode_runtime_once()
if not executed:
self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed)
self._wait_arbiter()
continue
self._mark_arbiter_tick(stage, reason, policy_allowed)

View File

@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Any, Dict
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
class EngineRuntimeDelegates:
@staticmethod
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
return EngineRuntimeFacade._safe_component_snapshot(component)
def _build_stage_counters(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_counters(request_registry, worker_state)
def _build_engine_policy_snapshot(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state)
async def _wait_for_engine_policy_admission(
self,
*,
request_id: str | None,
timeout_sec: float | None,
) -> tuple[float, Dict[str, Any]]:
return await self.engine_policy_arbiter.wait_for_policy_admission(
request_id=request_id,
timeout_sec=timeout_sec,
)
def _build_stage_summary(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)

View File

@ -1,20 +1,19 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineDecodeRuntimeOwner,
EngineDispatchTask,
EngineGpuPrepareTask,
EngineStatus,
EngineTaskQueueOwner,
SchedulerFinalizeTask,
SchedulerPendingJob,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
@ -42,29 +41,36 @@ class EngineStageCoordinator:
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
) -> None:
self.tts = tts
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner
self.update_request_state = update_request_state
self.merge_request_state_profile = merge_request_state_profile
self.fail_request_state = fail_request_state
self.get_engine_job = get_engine_job
self.register_engine_job = register_engine_job
self.fail_engine_jobs = fail_engine_jobs
self.complete_engine_job = complete_engine_job
self.add_engine_prefill_time = add_engine_prefill_time
self.add_engine_merge_time = add_engine_merge_time
self.add_engine_decode_time = add_engine_decode_time
self.enqueue_engine_finished_items = enqueue_engine_finished_items
self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
self._notify_arbiter: Callable[[], None] | None = None
self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None
self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None
self._wait_arbiter: Callable[[], None] | None = None
self.executor = EngineStageExecutor(
tts=tts,
scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner,
finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner,
update_request_state=update_request_state,
merge_request_state_profile=merge_request_state_profile,
fail_request_state=fail_request_state,
get_engine_job=get_engine_job,
register_engine_job=register_engine_job,
fail_engine_jobs=fail_engine_jobs,
complete_engine_job=complete_engine_job,
add_engine_prefill_time=add_engine_prefill_time,
add_engine_merge_time=add_engine_merge_time,
add_engine_decode_time=add_engine_decode_time,
enqueue_engine_finished_items=enqueue_engine_finished_items,
snapshot_engine_dispatch_state=snapshot_engine_dispatch_state,
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state,
)
self.orchestrator = EngineStageOrchestrator(
executor=self.executor,
scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner,
finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner,
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state,
)
def bind_arbiter(
self,
@ -74,57 +80,12 @@ class EngineStageCoordinator:
mark_arbiter_tick: Callable[[str, str, bool], None],
wait_arbiter: Callable[[], None],
) -> None:
self._notify_arbiter = notify_arbiter
self._select_stage = select_stage
self._mark_arbiter_tick = mark_arbiter_tick
self._wait_arbiter = wait_arbiter
def notify_arbiter(self) -> None:
if self._notify_arbiter is not None:
self._notify_arbiter()
@staticmethod
def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
if future.done():
return
future.set_exception(error)
def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
except RuntimeError:
pass
@staticmethod
def _resolve_prepare_future(
future: asyncio.Future,
payload: tuple[T2SRequestState, float, float],
) -> None:
if future.done():
return
future.set_result(payload)
def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
except RuntimeError:
pass
def _notify_prepare_result(
self,
task: EngineGpuPrepareTask,
payload: tuple[T2SRequestState, float, float],
) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload)
except RuntimeError:
pass
self.orchestrator.bind_arbiter(
notify_arbiter=notify_arbiter,
select_stage=select_stage,
mark_arbiter_tick=mark_arbiter_tick,
wait_arbiter=wait_arbiter,
)
async def prepare_state_via_engine_gpu_queue(
self,
@ -133,59 +94,17 @@ class EngineStageCoordinator:
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
if engine_request_id not in [None, ""]:
self.update_request_state(
str(engine_request_id),
EngineStatus.GPU_PREPARING,
{
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
"text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms),
},
)
loop = asyncio.get_running_loop()
done_future = loop.create_future()
task = EngineGpuPrepareTask(
request_id=spec.request_id,
cpu_stage=cpu_stage,
done_loop=loop,
done_future=done_future,
engine_request_id=engine_request_id or spec.request_id,
enqueue_time=time.perf_counter(),
return await self.executor.prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
self.prepare_queue_owner.enqueue(task)
self.notify_arbiter()
state, prepare_exec_started_at, prepare_exec_finished_at = await done_future
return state, prepare_exec_started_at, prepare_exec_finished_at
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
self.update_request_state(
job.engine_request_id,
EngineStatus.READY_FOR_FINALIZE,
{
"finish_reason": task.item.finish_reason,
"semantic_len": int(task.item.semantic_tokens.shape[0]),
"finish_idx": int(task.item.finish_idx),
},
)
self.finalize_queue_owner.enqueue_many(tasks)
self.notify_arbiter()
self.executor.enqueue_worker_finished_for_finalize(tasks)
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
return self.finalize_queue_owner.take_finalize_batch(
finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
use_vocoder=bool(self.tts.configs.use_vocoder),
)
return self.executor.take_engine_finalize_batch_nonblocking()
async def enqueue_prepared_state_for_dispatch(
self,
@ -201,220 +120,36 @@ class EngineStageCoordinator:
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
task = EngineDispatchTask(
request_id=state.request_id,
return await self.executor.enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(speed_factor),
sample_steps=int(sample_steps),
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms),
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id or state.request_id,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
enqueue_time=time.perf_counter(),
)
self.dispatch_queue_owner.enqueue(task)
self.notify_arbiter()
self.merge_request_state_profile(
task.engine_request_id or task.request_id,
{
"engine_dispatch_queue_depth_on_enqueue": int(
self.snapshot_engine_dispatch_state()["waiting_count"]
),
},
)
return task
def peek_queue_age_ms(self, queue_name: str) -> float:
if queue_name == "prepare":
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "finalize":
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
if queue_name == "decode_runtime_pending":
return self.decode_runtime_owner.pending_age_ms()
return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time")
return self.orchestrator.peek_queue_age_ms(queue_name)
def has_pending_work(self) -> bool:
if self.scheduler_worker.is_engine_decode_control_enabled():
if self.decode_runtime_owner.has_pending_jobs():
return True
if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get(
"active_request_count", 0
) > 0:
return True
if self.prepare_queue_owner.has_items():
return True
if self.finalize_queue_owner.has_items():
return True
return self.dispatch_queue_owner.has_items()
return self.orchestrator.has_pending_work()
def run_engine_prepare_once(self) -> bool:
task = self.prepare_queue_owner.pop_left()
if task is None:
return False
queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
try:
state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
)
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
)
self.prepare_queue_owner.mark_completed(1)
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
return True
except Exception as exc:
task.error = str(exc)
self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
self._notify_prepare_error(task, exc)
return True
return self.executor.run_engine_prepare_once()
def run_engine_finalize_once(self) -> bool:
tasks = self.take_engine_finalize_batch_nonblocking()
if not tasks:
return False
self.scheduler_worker.begin_finalize_execution(len(tasks))
try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
jobs_and_items.append((job, task.item))
if not jobs_and_items:
return False
now = time.perf_counter()
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
for job, item in jobs_and_items:
self.update_request_state(
job.engine_request_id,
EngineStatus.FINALIZING,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
},
)
synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
for job, _ in jobs_and_items:
job.synth_ms += float(synth_ms)
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc:
self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
finally:
self.scheduler_worker.end_finalize_execution(len(tasks))
self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
return True
return self.executor.run_engine_finalize_once()
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
if not bool(policy_snapshot.get("allowed", True)):
return False
dispatch_task = self.dispatch_queue_owner.pop_left()
if dispatch_task is None:
return False
dispatched_at = time.perf_counter()
dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
dispatch_task.engine_policy_snapshot = dict(policy_snapshot)
try:
worker_job = self.scheduler_worker.submit(
state=dispatch_task.state,
speed_factor=dispatch_task.speed_factor,
sample_steps=dispatch_task.sample_steps,
media_type=dispatch_task.media_type,
prepare_wall_ms=dispatch_task.prepare_wall_ms,
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
done_loop=dispatch_task.done_loop,
done_future=dispatch_task.done_future,
engine_request_id=dispatch_task.engine_request_id,
timeout_sec=dispatch_task.timeout_sec,
skip_capacity_wait=True,
admission_wait_ms_override=0.0,
admission_snapshot_override=dict(worker_state),
engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(),
)
dispatch_task.worker_job = worker_job
self.register_engine_job(worker_job)
if self.scheduler_worker.is_engine_decode_control_enabled():
self.decode_runtime_owner.enqueue_pending_job(worker_job)
self.notify_arbiter()
self.dispatch_queue_owner.mark_completed(1)
return True
except Exception as exc:
dispatch_task.error = str(exc)
self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
self._notify_dispatch_error(dispatch_task, exc)
return True
return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state)
def run_engine_decode_runtime_once(self) -> bool:
if not self.scheduler_worker.is_engine_decode_control_enabled():
return False
runtime_state = self.snapshot_engine_decode_runtime_state()
pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0
)
result = self.scheduler_worker.execute_decode_cycle(
pending_jobs=pending_jobs,
active_batch=self.decode_runtime_owner.get_active_batch(),
external_bookkeeping=True,
)
prefill_phase = dict(result.get("prefill_phase") or {})
if prefill_phase.get("error"):
self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error")))
else:
prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0)))
self.add_engine_merge_time(
[] if result.get("active_batch") is None else list(result["active_batch"].request_ids),
float(prefill_phase.get("merge_elapsed_s", 0.0)),
)
self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))
decode_phase = dict(result.get("decode_phase") or {})
if decode_phase.get("error"):
self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error")))
else:
self.add_engine_decode_time(
list(decode_phase.get("request_ids") or []),
float(decode_phase.get("decode_elapsed_s", 0.0)),
)
self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))
self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
if result.get("executed", False):
self.decode_runtime_owner.refresh_state("engine_decode_cycle")
return bool(result.get("executed", False))
return self.executor.run_engine_decode_runtime_once()
def run_engine_arbiter_loop(self) -> None:
if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None:
raise RuntimeError("arbiter callbacks are not bound")
while True:
if not self.has_pending_work():
self._mark_arbiter_tick("idle", "no_pending_work", True)
self._wait_arbiter()
continue
stage, reason, policy_snapshot, worker_state = self._select_stage()
policy_allowed = bool(policy_snapshot.get("allowed", True))
executed = False
if stage == "prepare":
executed = self.run_engine_prepare_once()
elif stage == "finalize":
executed = self.run_engine_finalize_once()
elif stage == "decode_dispatch":
executed = self.run_engine_dispatch_once(policy_snapshot, worker_state)
elif stage == "decode_runtime":
executed = self.run_engine_decode_runtime_once()
if not executed:
self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed)
self._wait_arbiter()
continue
self._mark_arbiter_tick(stage, reason, policy_allowed)
self.orchestrator.run_engine_arbiter_loop()

View File

@ -0,0 +1,358 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineDecodeRuntimeOwner,
EngineDispatchTask,
EngineGpuPrepareTask,
EngineStatus,
EngineTaskQueueOwner,
SchedulerFinalizeTask,
SchedulerPendingJob,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineStageExecutor:
def __init__(
self,
*,
tts: TTS,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
fail_request_state: Callable[[str, str], None],
get_engine_job: Callable[[str], SchedulerPendingJob | None],
register_engine_job: Callable[[SchedulerPendingJob], None],
fail_engine_jobs: Callable[[List[str], str], None],
complete_engine_job: Callable[..., None],
add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
add_engine_merge_time: Callable[[List[str], float], None],
add_engine_decode_time: Callable[[List[str], float], None],
enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
) -> None:
self.tts = tts
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner
self.update_request_state = update_request_state
self.merge_request_state_profile = merge_request_state_profile
self.fail_request_state = fail_request_state
self.get_engine_job = get_engine_job
self.register_engine_job = register_engine_job
self.fail_engine_jobs = fail_engine_jobs
self.complete_engine_job = complete_engine_job
self.add_engine_prefill_time = add_engine_prefill_time
self.add_engine_merge_time = add_engine_merge_time
self.add_engine_decode_time = add_engine_decode_time
self.enqueue_engine_finished_items = enqueue_engine_finished_items
self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
self._notify_arbiter: Callable[[], None] | None = None
def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None:
self._notify_arbiter = notify_arbiter
def notify_arbiter(self) -> None:
if self._notify_arbiter is not None:
self._notify_arbiter()
@staticmethod
def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
if future.done():
return
future.set_exception(error)
def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
except RuntimeError:
pass
@staticmethod
def _resolve_prepare_future(
future: asyncio.Future,
payload: tuple[T2SRequestState, float, float],
) -> None:
if future.done():
return
future.set_result(payload)
def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
except RuntimeError:
pass
def _notify_prepare_result(
self,
task: EngineGpuPrepareTask,
payload: tuple[T2SRequestState, float, float],
) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload)
except RuntimeError:
pass
async def prepare_state_via_engine_gpu_queue(
self,
*,
spec: Any,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
if engine_request_id not in [None, ""]:
self.update_request_state(
str(engine_request_id),
EngineStatus.GPU_PREPARING,
{
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
"text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms),
},
)
loop = asyncio.get_running_loop()
done_future = loop.create_future()
task = EngineGpuPrepareTask(
request_id=spec.request_id,
cpu_stage=cpu_stage,
done_loop=loop,
done_future=done_future,
engine_request_id=engine_request_id or spec.request_id,
enqueue_time=time.perf_counter(),
)
self.prepare_queue_owner.enqueue(task)
self.notify_arbiter()
return await done_future
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
self.update_request_state(
job.engine_request_id,
EngineStatus.READY_FOR_FINALIZE,
{
"finish_reason": task.item.finish_reason,
"semantic_len": int(task.item.semantic_tokens.shape[0]),
"finish_idx": int(task.item.finish_idx),
},
)
self.finalize_queue_owner.enqueue_many(tasks)
self.notify_arbiter()
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
return self.finalize_queue_owner.take_finalize_batch(
finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
use_vocoder=bool(self.tts.configs.use_vocoder),
)
async def enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
task = EngineDispatchTask(
request_id=state.request_id,
state=state,
speed_factor=float(speed_factor),
sample_steps=int(sample_steps),
media_type=media_type,
prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms),
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id or state.request_id,
timeout_sec=timeout_sec,
enqueue_time=time.perf_counter(),
)
self.dispatch_queue_owner.enqueue(task)
self.notify_arbiter()
self.merge_request_state_profile(
task.engine_request_id or task.request_id,
{
"engine_dispatch_queue_depth_on_enqueue": int(
self.snapshot_engine_dispatch_state()["waiting_count"]
),
},
)
return task
def run_engine_prepare_once(self) -> bool:
task = self.prepare_queue_owner.pop_left()
if task is None:
return False
queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
try:
state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
)
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
)
self.prepare_queue_owner.mark_completed(1)
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
return True
except Exception as exc:
task.error = str(exc)
self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
self._notify_prepare_error(task, exc)
return True
def run_engine_finalize_once(self) -> bool:
tasks = self.take_engine_finalize_batch_nonblocking()
if not tasks:
return False
self.scheduler_worker.begin_finalize_execution(len(tasks))
try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
jobs_and_items.append((job, task.item))
if not jobs_and_items:
return False
now = time.perf_counter()
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
for job, item in jobs_and_items:
self.update_request_state(
job.engine_request_id,
EngineStatus.FINALIZING,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
},
)
synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
for job, _ in jobs_and_items:
job.synth_ms += float(synth_ms)
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc:
self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
finally:
self.scheduler_worker.end_finalize_execution(len(tasks))
self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
return True
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
if not bool(policy_snapshot.get("allowed", True)):
return False
dispatch_task = self.dispatch_queue_owner.pop_left()
if dispatch_task is None:
return False
dispatched_at = time.perf_counter()
dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
dispatch_task.engine_policy_snapshot = dict(policy_snapshot)
try:
worker_job = self.scheduler_worker.submit(
state=dispatch_task.state,
speed_factor=dispatch_task.speed_factor,
sample_steps=dispatch_task.sample_steps,
media_type=dispatch_task.media_type,
prepare_wall_ms=dispatch_task.prepare_wall_ms,
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
done_loop=dispatch_task.done_loop,
done_future=dispatch_task.done_future,
engine_request_id=dispatch_task.engine_request_id,
timeout_sec=dispatch_task.timeout_sec,
skip_capacity_wait=True,
admission_wait_ms_override=0.0,
admission_snapshot_override=dict(worker_state),
engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(),
)
dispatch_task.worker_job = worker_job
self.register_engine_job(worker_job)
if self.scheduler_worker.is_engine_decode_control_enabled():
self.decode_runtime_owner.enqueue_pending_job(worker_job)
self.notify_arbiter()
self.dispatch_queue_owner.mark_completed(1)
return True
except Exception as exc:
dispatch_task.error = str(exc)
self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
self._notify_dispatch_error(dispatch_task, exc)
return True
def run_engine_decode_runtime_once(self) -> bool:
if not self.scheduler_worker.is_engine_decode_control_enabled():
return False
runtime_state = self.snapshot_engine_decode_runtime_state()
pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0
)
result = self.scheduler_worker.execute_decode_cycle(
pending_jobs=pending_jobs,
active_batch=self.decode_runtime_owner.get_active_batch(),
external_bookkeeping=True,
)
prefill_phase = dict(result.get("prefill_phase") or {})
if prefill_phase.get("error"):
self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error")))
else:
prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0)))
self.add_engine_merge_time(
[] if result.get("active_batch") is None else list(result["active_batch"].request_ids),
float(prefill_phase.get("merge_elapsed_s", 0.0)),
)
self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))
decode_phase = dict(result.get("decode_phase") or {})
if decode_phase.get("error"):
self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error")))
else:
self.add_engine_decode_time(
list(decode_phase.get("request_ids") or []),
float(decode_phase.get("decode_elapsed_s", 0.0)),
)
self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))
self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
if result.get("executed", False):
self.decode_runtime_owner.refresh_state("engine_decode_cycle")
return bool(result.get("executed", False))