Add unified engine components and API for enhanced TTS processing

Introduce multiple new modules including unified_engine_api, unified_engine_audio, unified_engine_bridge, unified_engine_builder, unified_engine_components, unified_engine_delegates, and unified_engine_runtime. These additions provide a comprehensive framework for managing TTS requests, audio packing, and engine state management, significantly improving the architecture and maintainability of the TTS system. The new structure supports asynchronous operations and enhances overall performance through better request handling and processing capabilities.
This commit is contained in:
baicai-1145 2026-03-11 08:32:56 +08:00
parent 06d6b67f73
commit d1ec7d9e54
10 changed files with 5724 additions and 4791 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,106 @@
from __future__ import annotations
import subprocess
import threading
import wave
from io import BytesIO
import numpy as np
import soundfile as sf
import torch
def set_scheduler_seed(seed: int):
    """Seed the NumPy and PyTorch random generators.

    An empty string, ``None``, or a value that is negative after ``int``
    conversion leaves every generator untouched.
    """
    if seed in ("", None):
        return
    seed_value = int(seed)
    if seed_value < 0:
        return
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if not torch.cuda.is_available():
        return
    # CUDA is present: seed the current device, then every device.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as mono OGG into ``io_buffer`` and return the buffer.

    The encode is attempted on a helper thread created with an enlarged
    stack (presumably because the encoder needs more stack than the default
    thread provides -- TODO confirm). If the stack size cannot be changed or
    the thread cannot be started, the encode runs on the calling thread.

    Args:
        io_buffer: Destination buffer handed to ``soundfile.SoundFile``.
        data: Samples to write; written as a single channel.
        rate: Sample rate in Hz.

    Returns:
        The same ``io_buffer`` that was passed in.

    Raises:
        Exception: Whatever the encoder raised. Errors from the helper thread
            are re-raised here instead of being lost with the thread.
    """

    def handle_pack_ogg():
        with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
            audio_file.write(data)

    worker_errors = []

    def run_in_worker():
        # A thread's exception never reaches the caller of join(); capture it
        # so that a failed encode is not mistaken for a successful one.
        try:
            handle_pack_ogg()
        except Exception as exc:
            worker_errors.append(exc)

    stack_size = 4096 * 4096
    worker = None
    try:
        # threading.stack_size() is process-wide; remember the old value so it
        # can be restored once our thread has been created.
        previous_stack_size = threading.stack_size(stack_size)
    except (RuntimeError, ValueError):
        previous_stack_size = None

    if previous_stack_size is not None:
        worker = threading.Thread(target=run_in_worker)
        try:
            worker.start()
        except (RuntimeError, ValueError):
            worker = None
        finally:
            try:
                threading.stack_size(previous_stack_size)
            except (RuntimeError, ValueError):
                # Restoring is best-effort; the encode itself is unaffected.
                pass

    if worker is None:
        # Fallback: encode on the calling thread.
        handle_pack_ogg()
    else:
        worker.join()
        if worker_errors:
            raise worker_errors[0]
    return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Append the raw bytes of ``data`` to ``io_buffer`` and return the buffer.

    ``rate`` is accepted for signature parity with the other packers and is
    not used.
    """
    payload = data.tobytes()
    io_buffer.write(payload)
    return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as WAV into a new ``BytesIO`` and return it.

    NOTE: the ``io_buffer`` argument is not written to; a fresh buffer is
    created and returned, so callers must use the return value.
    """
    wav_buffer = BytesIO()
    sf.write(wav_buffer, data, rate, format="wav")
    return wav_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` to ADTS AAC with ffmpeg and append it to ``io_buffer``.

    The bytes of ``data`` are piped to ffmpeg, which is told to read them as
    signed 16-bit little-endian mono PCM at ``rate`` Hz (assumes ``data`` is
    already int16 -- TODO confirm against callers).

    Args:
        io_buffer: Destination buffer; the encoded stream is written to it.
        data: Samples whose raw bytes are sent to ffmpeg's stdin.
        rate: Sample rate in Hz.

    Returns:
        The same ``io_buffer`` that was passed in.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status; the message
            carries ffmpeg's stderr output.
    """
    command = [
        "ffmpeg",
        "-f",
        "s16le",
        "-ar",
        str(rate),
        "-ac",
        "1",
        "-i",
        "pipe:0",
        "-c:a",
        "aac",
        "-b:a",
        "192k",
        "-vn",
        "-f",
        "adts",
        "pipe:1",
    ]
    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, err = process.communicate(input=data.tobytes())
    if process.returncode != 0:
        # Surface encoder failures instead of returning empty/garbage audio.
        detail = err.decode("utf-8", errors="replace").strip()
        raise RuntimeError(f"ffmpeg AAC encoding failed with exit code {process.returncode}: {detail}")
    io_buffer.write(out)
    return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Pack ``data`` in the format named by ``media_type`` and rewind the result.

    ``"ogg"``, ``"aac"`` and ``"wav"`` select the matching packer; any other
    value falls back to raw bytes. The returned buffer is positioned at 0.
    """
    if media_type == "ogg":
        packer = pack_ogg
    elif media_type == "aac":
        packer = pack_aac
    elif media_type == "wav":
        packer = pack_wav
    else:
        packer = pack_raw
    packed = packer(io_buffer, data, rate)
    packed.seek(0)
    return packed
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    """Return the bytes of a WAV stream holding ``frame_input``.

    With the default empty ``frame_input`` the result is just the WAV header
    for the given channel count, sample width (bytes) and sample rate.
    """
    buffer = BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    return buffer.getvalue()

View File

@ -0,0 +1,310 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class EngineBridgeFacade:
    """Forwarding layer over the engine components stored on ``owner``.

    Collaborators are resolved from ``owner`` on every access through the
    properties below, so the facade keeps no state besides ``owner`` itself.
    """

    def __init__(self, owner: Any) -> None:
        self.owner = owner

    # ------------------------------------------------------------------
    # Collaborator accessors (each reads the same-named owner attribute)
    # ------------------------------------------------------------------
    @property
    def request_registry(self):
        return self.owner.request_registry

    @property
    def engine_prepare_queue_owner(self):
        return self.owner.engine_prepare_queue_owner

    @property
    def engine_finalize_queue_owner(self):
        return self.owner.engine_finalize_queue_owner

    @property
    def engine_dispatch_queue_owner(self):
        return self.owner.engine_dispatch_queue_owner

    @property
    def engine_decode_runtime_owner(self):
        return self.owner.engine_decode_runtime_owner

    @property
    def engine_job_registry(self):
        return self.owner.engine_job_registry

    @property
    def scheduler_worker(self):
        return self.owner.scheduler_worker

    @property
    def engine_stage_coordinator(self):
        return self.owner.engine_stage_coordinator

    @property
    def engine_policy_arbiter(self):
        return self.owner.engine_policy_arbiter

    # ------------------------------------------------------------------
    # Request registry
    # ------------------------------------------------------------------
    def _register_request_state(
        self,
        request_id: str,
        api_mode: str,
        backend: str,
        media_type: str,
        response_streaming: bool,
        deadline_ts: float | None = None,
        meta: Optional[Dict[str, Any]] = None,
    ) -> EngineRequestState:
        """Register a request with the request registry and return its state."""
        registry = self.request_registry
        return registry.register(
            request_id=request_id,
            api_mode=api_mode,
            backend=backend,
            media_type=media_type,
            response_streaming=response_streaming,
            deadline_ts=deadline_ts,
            meta=meta,
        )

    def _update_request_state(
        self,
        request_id: str,
        status: str,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Forward a status update to the request registry."""
        registry = self.request_registry
        registry.update(request_id, status, extra)

    def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward a profile merge to the request registry."""
        registry = self.request_registry
        registry.merge_profile(request_id, extra)

    def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Mark the request complete in the request registry."""
        registry = self.request_registry
        registry.complete(request_id, extra)

    def _fail_request_state(self, request_id: str, error: str) -> None:
        """Mark the request failed in the request registry."""
        registry = self.request_registry
        registry.fail(request_id, error)

    def _snapshot_request_registry(self) -> Dict[str, Any]:
        """Return the request registry's snapshot."""
        registry = self.request_registry
        return registry.snapshot()

    # ------------------------------------------------------------------
    # Queue snapshots
    # ------------------------------------------------------------------
    def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
        """Snapshot the prepare queue, asking for at most 16 request ids."""
        queue_owner = self.engine_prepare_queue_owner
        return queue_owner.snapshot(max_request_ids=16)

    def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
        """Snapshot the finalize queue, asking for at most 16 request ids."""
        queue_owner = self.engine_finalize_queue_owner
        return queue_owner.snapshot(max_request_ids=16)

    def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
        """Snapshot the dispatch queue plus a copy of the last policy snapshot."""
        queue_owner = self.engine_dispatch_queue_owner
        # Copy so the snapshot consumer cannot mutate the owner's dict.
        last_policy = dict(self.owner.engine_dispatch_last_snapshot or {})
        return queue_owner.snapshot(
            max_request_ids=16,
            extra={"last_policy_snapshot": last_policy},
        )

    # ------------------------------------------------------------------
    # Job registry
    # ------------------------------------------------------------------
    def _register_engine_job(self, job: SchedulerPendingJob) -> None:
        """Register ``job`` with the job registry (``keep_job=True``)."""
        registry = self.engine_job_registry
        registry.register(job, keep_job=True)

    def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Look up a job by request id in the job registry."""
        registry = self.engine_job_registry
        return registry.get(request_id)

    def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Pop a job by request id from the job registry."""
        registry = self.engine_job_registry
        return registry.pop(request_id)

    def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
        """Snapshot the job registry, asking for at most 32 request ids."""
        registry = self.engine_job_registry
        return registry.snapshot(max_request_ids=32)

    def _is_engine_drained(self) -> bool:
        """Return True when no queue, registry, batch or worker counter shows work."""
        # All of these are queried up front, regardless of earlier results.
        prepare_empty = self.engine_prepare_queue_owner.is_drained()
        dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
        finalize_empty = self.engine_finalize_queue_owner.is_drained()
        decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
        job_empty = self.engine_job_registry.is_empty()
        worker_state = self.scheduler_worker.snapshot()

        queues_idle = (
            prepare_empty
            and dispatch_empty
            and finalize_empty
            and decode_pending_empty
            and job_empty
        )
        if not queues_idle:
            return False
        if self.engine_decode_runtime_owner.get_active_batch() is not None:
            return False
        counter_keys = ("prepare_inflight", "finalize_inflight", "finalize_pending")
        return all(int(worker_state.get(key, 0)) <= 0 for key in counter_keys)

    def _record_engine_job_done(self, request_id: str) -> None:
        """Drop the finished job from the registry and tell the worker."""
        self.engine_job_registry.mark_finished_and_remove(request_id)
        self.scheduler_worker.record_external_job_done(request_id)

    def _complete_engine_job(
        self,
        job: SchedulerPendingJob,
        item: T2SFinishedItem,
        *,
        sample_rate: int,
        audio_data: np.ndarray,
    ) -> None:
        """Build the job result and complete the job via the completion bridge."""
        bridge = self.scheduler_worker.completion_bridge
        bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
        runtime_request_id = job.engine_request_id
        runtime_extra = bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate)
        bridge.complete_job(
            job,
            runtime_request_id=runtime_request_id,
            runtime_extra=runtime_extra,
            # Bind the id now; the callback may run later.
            on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid),
        )

    def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
        """Fail every still-registered job among ``request_ids`` with ``error``."""
        if not request_ids:
            return
        bridge = self.scheduler_worker.completion_bridge
        for current_id in request_ids:
            job = self._get_engine_job(current_id)
            if job is not None:
                bridge.fail_job(
                    job,
                    error=error,
                    # Default argument pins this iteration's id.
                    on_job_finished=lambda rid=current_id: self._record_engine_job_done(rid),
                )

    # ------------------------------------------------------------------
    # Timing accumulation
    # ------------------------------------------------------------------
    def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
        """Add ``elapsed_s`` (converted to ms) to each job's ``prefill_ms``."""
        elapsed_ms = float(elapsed_s) * 1000.0
        for job in jobs:
            job.prefill_ms += elapsed_ms

    def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Add ``elapsed_s`` (converted to ms) to ``merge_ms`` of each known job."""
        elapsed_ms = float(elapsed_s) * 1000.0
        for current_id in request_ids:
            job = self._get_engine_job(current_id)
            if job is None:
                continue
            job.merge_ms += elapsed_ms

    def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Accumulate decode time/steps and flag first-step jobs as ACTIVE_DECODE."""
        elapsed_ms = float(elapsed_s) * 1000.0
        newly_active: List[str] = []
        for current_id in request_ids:
            job = self._get_engine_job(current_id)
            if job is not None:
                if job.decode_steps == 0:
                    # First decode step for this job.
                    newly_active.append(job.engine_request_id)
                job.decode_ms += elapsed_ms
                job.decode_steps += 1
        # Status updates happen only after all counters were bumped.
        for engine_request_id in newly_active:
            self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None)

    def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
        """Wrap finished items as finalize tasks sharing one enqueue timestamp."""
        if not items:
            return
        stamp = time.perf_counter()
        finalize_tasks = []
        for finished in items:
            finalize_tasks.append(
                SchedulerFinalizeTask(request_id=finished.request_id, item=finished, enqueued_time=stamp)
            )
        self._enqueue_worker_finished_for_finalize(finalize_tasks)

    # ------------------------------------------------------------------
    # Decode runtime
    # ------------------------------------------------------------------
    def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
        """Return the decode runtime owner's pending-queue snapshot."""
        runtime_owner = self.engine_decode_runtime_owner
        return runtime_owner.snapshot_pending_queue_state()

    @staticmethod
    def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
        """Forward to ``EngineDecodeRuntimeOwner.summarize_active_batch``."""
        summary = EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)
        return summary

    def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
        """Ask the decode runtime owner to refresh its state."""
        runtime_owner = self.engine_decode_runtime_owner
        runtime_owner.refresh_state(last_event)

    def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
        """Apply a worker snapshot unless it is empty or engine decode control is on."""
        if not snapshot or self.scheduler_worker.is_engine_decode_control_enabled():
            return
        self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot)

    def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
        """Return the decode runtime owner's state snapshot."""
        runtime_owner = self.engine_decode_runtime_owner
        return runtime_owner.snapshot_state()

    # ------------------------------------------------------------------
    # Arbiter and stage coordinator
    # ------------------------------------------------------------------
    def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
        """Return the policy arbiter's state snapshot."""
        arbiter = self.engine_policy_arbiter
        return arbiter.snapshot_state()

    def _notify_engine_arbiter(self) -> None:
        """Notify the policy arbiter."""
        arbiter = self.engine_policy_arbiter
        arbiter.notify()

    def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
        """Queue ``job`` on the coordinator's decode runtime owner, then notify."""
        coordinator = self.engine_stage_coordinator
        coordinator.decode_runtime_owner.enqueue_pending_job(job)
        self._notify_engine_arbiter()

    def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
        """Take pending decode jobs from the coordinator's decode runtime owner."""
        coordinator = self.engine_stage_coordinator
        return coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch)

    def _peek_queue_age_ms(self, queue_name: str) -> float:
        """Forward to the stage coordinator's ``peek_queue_age_ms``."""
        coordinator = self.engine_stage_coordinator
        return coordinator.peek_queue_age_ms(queue_name)

    def _engine_has_pending_work(self) -> bool:
        """Forward to the stage coordinator's ``has_pending_work``."""
        coordinator = self.engine_stage_coordinator
        return coordinator.has_pending_work()

    async def _prepare_state_via_engine_gpu_queue(
        self,
        *,
        spec: SchedulerRequestSpec,
        prepare_submit_at: float,
        engine_request_id: str | None,
    ) -> tuple[T2SRequestState, float, float]:
        """Await the stage coordinator's ``prepare_state_via_engine_gpu_queue``."""
        coordinator = self.engine_stage_coordinator
        prepared = await coordinator.prepare_state_via_engine_gpu_queue(
            spec=spec,
            prepare_submit_at=prepare_submit_at,
            engine_request_id=engine_request_id,
        )
        return prepared

    def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
        """Hand finalize tasks to the stage coordinator."""
        coordinator = self.engine_stage_coordinator
        coordinator.enqueue_worker_finished_for_finalize(tasks)

    def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
        """Take a finalize batch from the stage coordinator."""
        coordinator = self.engine_stage_coordinator
        return coordinator.take_engine_finalize_batch_nonblocking()

    async def _enqueue_prepared_state_for_dispatch(
        self,
        *,
        state: T2SRequestState,
        speed_factor: float,
        sample_steps: int,
        media_type: str,
        prepare_wall_ms: float,
        prepare_profile_total_ms: float,
        done_loop: asyncio.AbstractEventLoop | None,
        done_future: asyncio.Future | None,
        engine_request_id: str | None,
        timeout_sec: float | None,
    ) -> EngineDispatchTask:
        """Await the stage coordinator's ``enqueue_prepared_state_for_dispatch``."""
        coordinator = self.engine_stage_coordinator
        dispatch_task = await coordinator.enqueue_prepared_state_for_dispatch(
            state=state,
            speed_factor=speed_factor,
            sample_steps=sample_steps,
            media_type=media_type,
            prepare_wall_ms=prepare_wall_ms,
            prepare_profile_total_ms=prepare_profile_total_ms,
            done_loop=done_loop,
            done_future=done_future,
            engine_request_id=engine_request_id,
            timeout_sec=timeout_sec,
        )
        return dispatch_task

    def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
        """Forward a tick record to the policy arbiter."""
        arbiter = self.engine_policy_arbiter
        arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)

    def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
        """Ask the arbiter for the next stage and remember the policy snapshot."""
        selection = self.engine_policy_arbiter.select_stage()
        stage, reason, policy_snapshot, worker_state = selection
        # Store a copy on the owner for later dispatch-state snapshots.
        self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot)
        return stage, reason, policy_snapshot, worker_state

    def _run_engine_prepare_once(self) -> bool:
        """Forward to the stage coordinator's ``run_engine_prepare_once``."""
        coordinator = self.engine_stage_coordinator
        return coordinator.run_engine_prepare_once()

    def _run_engine_finalize_once(self) -> bool:
        """Forward to the stage coordinator's ``run_engine_finalize_once``."""
        coordinator = self.engine_stage_coordinator
        return coordinator.run_engine_finalize_once()

    def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
        """Forward to the stage coordinator's ``run_engine_dispatch_once``."""
        coordinator = self.engine_stage_coordinator
        return coordinator.run_engine_dispatch_once(policy_snapshot, worker_state)

    def _run_engine_decode_runtime_once(self) -> bool:
        """Forward to the stage coordinator's ``run_engine_decode_runtime_once``."""
        coordinator = self.engine_stage_coordinator
        return coordinator.run_engine_decode_runtime_once()

    def _run_engine_arbiter_loop(self) -> None:
        """Forward to the stage coordinator's ``run_engine_arbiter_loop``."""
        coordinator = self.engine_stage_coordinator
        coordinator.run_engine_arbiter_loop()

View File

@ -0,0 +1,179 @@
from __future__ import annotations
import os
import threading
from typing import Any
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineArbiterConfig,
EngineDecodeRuntimeOwner,
EnginePolicyArbiterController,
EnginePolicyConfig,
EngineRequestRegistry,
EngineTaskQueueOwner,
ModelRegistry,
ReferenceRegistry,
RuntimeStateCallbacks,
SchedulerJobRegistry,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineCompositionBuilder:
    """Creates the engine's components and attaches them to ``owner``.

    ``build`` runs the ``_init_*`` steps in a fixed order; later steps read
    attributes that earlier steps assigned on ``owner``.
    """

    def __init__(self, owner: Any) -> None:
        self.owner = owner

    def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
        """Run every construction step, then start the arbiter thread."""
        self._init_registries_and_locks()
        self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)
        self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms)
        self._init_runtime_owners()
        self._init_stage_coordinator()
        self._init_arbiter()
        self._init_facades()
        self._start_arbiter_thread()

    def _init_registries_and_locks(self) -> None:
        """Attach registries, locks and the empty dispatch snapshot to the owner."""
        target = self.owner
        target.reference_registry = ReferenceRegistry()
        target.model_registry = ModelRegistry(
            t2s_weights_path=str(target.tts.configs.t2s_weights_path),
            vits_weights_path=str(target.tts.configs.vits_weights_path),
        )
        # The env value is clamped so the registry keeps at least one entry.
        recent_limit = max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64")))
        target.request_registry = EngineRequestRegistry(recent_limit=recent_limit)
        target.engine_job_registry = SchedulerJobRegistry(threading.Lock())
        target.direct_tts_lock = threading.RLock()
        target.management_lock = threading.RLock()
        target.engine_dispatch_last_snapshot = {}

    def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
        """Create the scheduler worker wired to the owner's state callbacks."""
        target = self.owner
        target.scheduler_worker = UnifiedSchedulerWorker(
            target.tts,
            max_steps=max_steps,
            micro_batch_wait_ms=micro_batch_wait_ms,
            runtime_callbacks=RuntimeStateCallbacks(
                update=target._update_request_state,
                complete=target._complete_request_state,
                fail=target._fail_request_state,
                decode_runtime_update=target._update_engine_decode_runtime_state,
            ),
            external_finalize_submit=target._enqueue_worker_finished_for_finalize,
        )

    def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None:
        """Build the policy and arbiter configs from env overrides and worker limits."""
        target = self.owner
        capacity = target.scheduler_worker.get_capacity_limits()
        prepare_cap = int(target.scheduler_worker.get_prepare_max_inflight())

        def clamped_int(env_name: str, default: int, floor: int = 0) -> int:
            # Env override, never below ``floor``.
            return max(floor, target._env_int(env_name, default))

        def clamped_float(env_name: str, default: float, floor: float) -> float:
            return max(floor, target._env_float(env_name, default))

        target.engine_policy_config = EnginePolicyConfig(
            enabled=target._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True),
            poll_wait_ms=clamped_float(
                "GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms), 1.0
            ),
            decode_backlog_soft_max=clamped_int(
                "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX",
                int(capacity["decode_backlog_max"]),
            ),
            finalize_pending_soft_max=clamped_int(
                "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX",
                int(capacity["finalize_pending_max"]),
            ),
            prepare_inflight_soft_max=clamped_int(
                "GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_cap
            ),
            active_decode_soft_max=clamped_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0),
            ready_for_prefill_soft_max=clamped_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0),
            active_request_soft_max=clamped_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0),
        )
        target.engine_arbiter_config = EngineArbiterConfig(
            poll_wait_ms=clamped_float(
                "GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms), 1.0
            ),
            decode_burst=clamped_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4, 1),
            prepare_aging_ms=clamped_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0, 0.0),
            finalize_aging_ms=clamped_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0, 0.0),
        )

    def _init_runtime_owners(self) -> None:
        """Create the decode runtime owner and the three task-queue owners."""
        target = self.owner
        worker = target.scheduler_worker
        target.engine_decode_runtime_owner = EngineDecodeRuntimeOwner(
            get_decode_runtime_counters=worker.get_decode_runtime_counters,
            get_micro_batch_wait_s=worker.get_micro_batch_wait_s,
        )
        target.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
        target.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
        target.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")

    def _init_stage_coordinator(self) -> None:
        """Create the stage coordinator from the owner's components and callbacks."""
        target = self.owner
        target.engine_stage_coordinator = EngineStageCoordinator(
            # Components created by earlier steps.
            tts=target.tts,
            scheduler_worker=target.scheduler_worker,
            prepare_queue_owner=target.engine_prepare_queue_owner,
            finalize_queue_owner=target.engine_finalize_queue_owner,
            dispatch_queue_owner=target.engine_dispatch_queue_owner,
            decode_runtime_owner=target.engine_decode_runtime_owner,
            # Owner methods passed as callbacks.
            update_request_state=target._update_request_state,
            merge_request_state_profile=target._merge_request_state_profile,
            fail_request_state=target._fail_request_state,
            get_engine_job=target._get_engine_job,
            register_engine_job=target._register_engine_job,
            fail_engine_jobs=target._fail_engine_jobs,
            complete_engine_job=target._complete_engine_job,
            add_engine_prefill_time=target._add_engine_prefill_time,
            add_engine_merge_time=target._add_engine_merge_time,
            add_engine_decode_time=target._add_engine_decode_time,
            enqueue_engine_finished_items=target._enqueue_engine_finished_items,
            snapshot_engine_dispatch_state=target._snapshot_engine_dispatch_state,
            snapshot_engine_decode_runtime_state=target._snapshot_engine_decode_runtime_state,
        )

    def _init_arbiter(self) -> None:
        """Create the policy arbiter and bind it to the stage coordinator."""
        target = self.owner
        target.engine_policy_arbiter = EnginePolicyArbiterController(
            policy_config=target.engine_policy_config,
            arbiter_config=target.engine_arbiter_config,
            snapshot_request_registry=target._snapshot_request_registry,
            get_worker_state=target.get_scheduler_state,
            snapshot_prepare_state=target._snapshot_engine_prepare_state,
            snapshot_finalize_state=target._snapshot_engine_finalize_state,
            snapshot_dispatch_state=target._snapshot_engine_dispatch_state,
            snapshot_decode_runtime_state=target._snapshot_engine_decode_runtime_state,
            snapshot_job_registry=target._snapshot_engine_job_registry,
            peek_queue_age_ms=target.engine_stage_coordinator.peek_queue_age_ms,
            merge_request_state_profile=target._merge_request_state_profile,
        )
        target.engine_stage_coordinator.bind_arbiter(
            notify_arbiter=target._notify_engine_arbiter,
            select_stage=target._select_engine_stage,
            # Adapter: the owner's method takes keyword-only arguments.
            mark_arbiter_tick=lambda stage, reason, policy_allowed: target._mark_arbiter_tick(
                stage=stage,
                reason=reason,
                policy_allowed=policy_allowed,
            ),
            wait_arbiter=target.engine_policy_arbiter.wait,
        )

    def _init_facades(self) -> None:
        """Attach the bridge, API and runtime facades to the owner."""
        target = self.owner
        target.bridge_facade = EngineBridgeFacade(target)
        target.api_facade = EngineApiFacade(target)
        target.runtime_facade = EngineRuntimeFacade(target)

    def _start_arbiter_thread(self) -> None:
        """Start the daemon thread that runs the owner's arbiter loop."""
        target = self.owner
        arbiter_thread = threading.Thread(
            target=target._run_engine_arbiter_loop,
            name="unified-engine-arbiter",
            daemon=True,
        )
        target.engine_arbiter_thread = arbiter_thread
        target.engine_arbiter_thread.start()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,446 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
class EngineBridgeDelegates:
    """Methods that forward one-to-one to ``self.bridge_facade``.

    The class assigns no attributes itself; the host object must provide
    ``bridge_facade`` before any of these methods is called.
    """

    # ---- request registry ----
    def _register_request_state(
        self,
        request_id: str,
        api_mode: str,
        backend: str,
        media_type: str,
        response_streaming: bool,
        deadline_ts: float | None = None,
        meta: Optional[Dict[str, Any]] = None,
    ) -> EngineRequestState:
        """Forward to ``bridge_facade._register_request_state`` using keywords."""
        facade = self.bridge_facade
        return facade._register_request_state(
            request_id=request_id,
            api_mode=api_mode,
            backend=backend,
            media_type=media_type,
            response_streaming=response_streaming,
            deadline_ts=deadline_ts,
            meta=meta,
        )

    def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward to ``bridge_facade._update_request_state``."""
        facade = self.bridge_facade
        facade._update_request_state(request_id, status, extra)

    def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward to ``bridge_facade._merge_request_state_profile``."""
        facade = self.bridge_facade
        facade._merge_request_state_profile(request_id, extra)

    # ---- queue snapshots ----
    def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_prepare_state``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_prepare_state()

    def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_finalize_state``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_finalize_state()

    def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_dispatch_state``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_dispatch_state()

    # ---- job registry ----
    def _register_engine_job(self, job: SchedulerPendingJob) -> None:
        """Forward to ``bridge_facade._register_engine_job``."""
        facade = self.bridge_facade
        facade._register_engine_job(job)

    def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Forward to ``bridge_facade._get_engine_job``."""
        facade = self.bridge_facade
        return facade._get_engine_job(request_id)

    def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
        """Forward to ``bridge_facade._pop_engine_job``."""
        facade = self.bridge_facade
        return facade._pop_engine_job(request_id)

    def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_job_registry``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_job_registry()

    def _is_engine_drained(self) -> bool:
        """Forward to ``bridge_facade._is_engine_drained``."""
        facade = self.bridge_facade
        return facade._is_engine_drained()

    def _record_engine_job_done(self, request_id: str) -> None:
        """Forward to ``bridge_facade._record_engine_job_done``."""
        facade = self.bridge_facade
        facade._record_engine_job_done(request_id)

    def _complete_engine_job(
        self,
        job: SchedulerPendingJob,
        item: T2SFinishedItem,
        *,
        sample_rate: int,
        audio_data: np.ndarray,
    ) -> None:
        """Forward to ``bridge_facade._complete_engine_job``."""
        facade = self.bridge_facade
        facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)

    def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
        """Forward to ``bridge_facade._fail_engine_jobs``."""
        facade = self.bridge_facade
        facade._fail_engine_jobs(request_ids, error)

    # ---- timing accumulation ----
    def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
        """Forward to ``bridge_facade._add_engine_prefill_time``."""
        facade = self.bridge_facade
        facade._add_engine_prefill_time(jobs, elapsed_s)

    def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Forward to ``bridge_facade._add_engine_merge_time``."""
        facade = self.bridge_facade
        facade._add_engine_merge_time(request_ids, elapsed_s)

    def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
        """Forward to ``bridge_facade._add_engine_decode_time``."""
        facade = self.bridge_facade
        facade._add_engine_decode_time(request_ids, elapsed_s)

    def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
        """Forward to ``bridge_facade._enqueue_engine_finished_items``."""
        facade = self.bridge_facade
        facade._enqueue_engine_finished_items(items)

    # ---- decode runtime ----
    def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_decode_pending_queue_state``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_decode_pending_queue_state()

    @staticmethod
    def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
        """Forward to the static ``EngineBridgeFacade._summarize_active_batch``."""
        summary = EngineBridgeFacade._summarize_active_batch(active_batch)
        return summary

    def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
        """Forward to ``bridge_facade._refresh_engine_decode_runtime_state``."""
        facade = self.bridge_facade
        facade._refresh_engine_decode_runtime_state(last_event)

    def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
        """Forward to ``bridge_facade._update_engine_decode_runtime_state``."""
        facade = self.bridge_facade
        facade._update_engine_decode_runtime_state(snapshot)

    def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_decode_runtime_state``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_decode_runtime_state()

    # ---- arbiter and stage coordinator ----
    def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_engine_arbiter_state``."""
        facade = self.bridge_facade
        return facade._snapshot_engine_arbiter_state()

    def _notify_engine_arbiter(self) -> None:
        """Forward to ``bridge_facade._notify_engine_arbiter``."""
        facade = self.bridge_facade
        facade._notify_engine_arbiter()

    def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
        """Forward to ``bridge_facade._enqueue_engine_decode_pending_job``."""
        facade = self.bridge_facade
        facade._enqueue_engine_decode_pending_job(job)

    def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
        """Forward to ``bridge_facade._take_engine_decode_pending_jobs_nonblocking``."""
        facade = self.bridge_facade
        return facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)

    def _peek_queue_age_ms(self, queue_name: str) -> float:
        """Forward to ``bridge_facade._peek_queue_age_ms``."""
        facade = self.bridge_facade
        return facade._peek_queue_age_ms(queue_name)

    def _engine_has_pending_work(self) -> bool:
        """Forward to ``bridge_facade._engine_has_pending_work``."""
        facade = self.bridge_facade
        return facade._engine_has_pending_work()

    async def _prepare_state_via_engine_gpu_queue(
        self,
        *,
        spec: SchedulerRequestSpec,
        prepare_submit_at: float,
        engine_request_id: str | None,
    ) -> tuple[T2SRequestState, float, float]:
        """Await ``bridge_facade._prepare_state_via_engine_gpu_queue``."""
        facade = self.bridge_facade
        prepared = await facade._prepare_state_via_engine_gpu_queue(
            spec=spec,
            prepare_submit_at=prepare_submit_at,
            engine_request_id=engine_request_id,
        )
        return prepared

    def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
        """Forward to ``bridge_facade._enqueue_worker_finished_for_finalize``."""
        facade = self.bridge_facade
        facade._enqueue_worker_finished_for_finalize(tasks)

    def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
        """Forward to ``bridge_facade._take_engine_finalize_batch_nonblocking``."""
        facade = self.bridge_facade
        return facade._take_engine_finalize_batch_nonblocking()

    async def _enqueue_prepared_state_for_dispatch(
        self,
        *,
        state: T2SRequestState,
        speed_factor: float,
        sample_steps: int,
        media_type: str,
        prepare_wall_ms: float,
        prepare_profile_total_ms: float,
        done_loop: asyncio.AbstractEventLoop | None,
        done_future: asyncio.Future | None,
        engine_request_id: str | None,
        timeout_sec: float | None,
    ) -> EngineDispatchTask:
        """Await ``bridge_facade._enqueue_prepared_state_for_dispatch``."""
        facade = self.bridge_facade
        dispatch_task = await facade._enqueue_prepared_state_for_dispatch(
            state=state,
            speed_factor=speed_factor,
            sample_steps=sample_steps,
            media_type=media_type,
            prepare_wall_ms=prepare_wall_ms,
            prepare_profile_total_ms=prepare_profile_total_ms,
            done_loop=done_loop,
            done_future=done_future,
            engine_request_id=engine_request_id,
            timeout_sec=timeout_sec,
        )
        return dispatch_task

    def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
        """Forward to ``bridge_facade._mark_arbiter_tick``."""
        facade = self.bridge_facade
        facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)

    def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
        """Forward to ``bridge_facade._select_engine_stage``."""
        facade = self.bridge_facade
        return facade._select_engine_stage()

    def _run_engine_prepare_once(self) -> bool:
        """Forward to ``bridge_facade._run_engine_prepare_once``."""
        facade = self.bridge_facade
        return facade._run_engine_prepare_once()

    def _run_engine_finalize_once(self) -> bool:
        """Forward to ``bridge_facade._run_engine_finalize_once``."""
        facade = self.bridge_facade
        return facade._run_engine_finalize_once()

    def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
        """Forward to ``bridge_facade._run_engine_dispatch_once``."""
        facade = self.bridge_facade
        return facade._run_engine_dispatch_once(policy_snapshot, worker_state)

    def _run_engine_decode_runtime_once(self) -> bool:
        """Forward to ``bridge_facade._run_engine_decode_runtime_once``."""
        facade = self.bridge_facade
        return facade._run_engine_decode_runtime_once()

    def _run_engine_arbiter_loop(self) -> None:
        """Forward to ``bridge_facade._run_engine_arbiter_loop``."""
        facade = self.bridge_facade
        facade._run_engine_arbiter_loop()

    # ---- request completion ----
    def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Forward to ``bridge_facade._complete_request_state``."""
        facade = self.bridge_facade
        facade._complete_request_state(request_id, extra)

    def _fail_request_state(self, request_id: str, error: str) -> None:
        """Forward to ``bridge_facade._fail_request_state``."""
        facade = self.bridge_facade
        facade._fail_request_state(request_id, error)

    def _snapshot_request_registry(self) -> Dict[str, Any]:
        """Forward to ``bridge_facade._snapshot_request_registry``."""
        facade = self.bridge_facade
        return facade._snapshot_request_registry()
class EngineApiDelegates:
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
    """Forward to ``api_facade._collect_request_summaries`` and return its result."""
    facade = self.api_facade
    return facade._collect_request_summaries(request_ids)
def _has_active_request(self, request_id: str) -> bool:
    """Forward to ``api_facade._has_active_request`` and return its result."""
    facade = self.api_facade
    return facade._has_active_request(request_id)
@staticmethod
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Forward to the static ``EngineApiFacade._build_request_meta``."""
    meta = EngineApiFacade._build_request_meta(payload)
    return meta
@staticmethod
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
    """Forward to the static ``EngineApiFacade._sum_profile_field``."""
    total = EngineApiFacade._sum_profile_field(items, key)
    return total
def _build_direct_segment_trace(
    self,
    segment_texts: Sequence[str],
    prepare_profiles: Sequence[Dict[str, Any]],
    worker_profiles: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Forward to ``api_facade._build_direct_segment_trace`` and return its result."""
    facade = self.api_facade
    return facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
    """Forward all keyword arguments to ``api_facade._build_direct_scheduler_profile``."""
    facade = self.api_facade
    return facade._build_direct_scheduler_profile(**kwargs)
def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
    """Forward all keyword arguments to ``api_facade._build_legacy_direct_profile``."""
    facade = self.api_facade
    return facade._build_legacy_direct_profile(**kwargs)
def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
    """Forward all keyword arguments to ``api_facade._build_scheduler_submit_profile``."""
    facade = self.api_facade
    return facade._build_scheduler_submit_profile(**kwargs)
@staticmethod
def _format_ms_header(value: Any) -> str:
    """Forward to the static ``EngineApiFacade._format_ms_header``."""
    formatted = EngineApiFacade._format_ms_header(value)
    return formatted
def _build_scheduler_submit_headers(
    self,
    *,
    request_id: str,
    media_type: str,
    sample_rate: int,
    profile: Dict[str, Any],
) -> Dict[str, str]:
    """Forward to ``api_facade._build_scheduler_submit_headers`` using keywords."""
    facade = self.api_facade
    headers = facade._build_scheduler_submit_headers(
        request_id=request_id,
        media_type=media_type,
        sample_rate=sample_rate,
        profile=profile,
    )
    return headers
def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
    """Forward all keyword arguments to ``api_facade._build_scheduler_debug_request_profile``."""
    facade = self.api_facade
    return facade._build_scheduler_debug_request_profile(**kwargs)
@staticmethod
def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
    """Forward all keyword arguments to the static ``EngineApiFacade._build_scheduler_debug_batch_profile``."""
    profile = EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs)
    return profile
def _normalize_lang(self, value: str | None) -> str | None:
    """Forward to ``api_facade._normalize_lang`` and return its result."""
    facade = self.api_facade
    return facade._normalize_lang(value)
@staticmethod
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
    """Forward to the static ``EngineApiFacade._aggregate_numeric_dicts``."""
    aggregated = EngineApiFacade._aggregate_numeric_dicts(items)
    return aggregated
def _apply_default_reference(self, req: dict) -> dict:
    """Forward to ``api_facade._apply_default_reference`` and return its result."""
    facade = self.api_facade
    return facade._apply_default_reference(req)
def check_params(self, req: dict) -> Optional[str]:
    """Forward to ``api_facade.check_params`` and return its result."""
    facade = self.api_facade
    return facade.check_params(req)
@staticmethod
def _base_request_defaults() -> Dict[str, Any]:
    """Forward to the static ``EngineApiFacade._base_request_defaults``."""
    defaults = EngineApiFacade._base_request_defaults()
    return defaults
def _normalize_engine_request(
    self,
    payload: dict | NormalizedEngineRequest,
    *,
    request_id: str | None = None,
    normalize_streaming: bool = False,
    error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
    """Forward to ``api_facade._normalize_engine_request`` with the same options."""
    facade = self.api_facade
    normalized = facade._normalize_engine_request(
        payload,
        request_id=request_id,
        normalize_streaming=normalize_streaming,
        error_prefix=error_prefix,
    )
    return normalized
@staticmethod
def _normalize_streaming_mode(req: dict) -> dict:
    """Forward to the static ``EngineApiFacade._normalize_streaming_mode``."""
    normalized = EngineApiFacade._normalize_streaming_mode(req)
    return normalized
@staticmethod
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
    """Forward to the static ``EngineApiFacade._is_aux_ref_enabled``."""
    enabled = EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths)
    return enabled
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
    """Forward to ``api_facade._select_direct_backend`` and return its result."""
    facade = self.api_facade
    return facade._select_direct_backend(normalized)
def _iter_legacy_direct_tts_bytes(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> Generator[bytes, None, None]:
return self.api_facade._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
return self.api_facade._should_use_scheduler_backend_for_direct(req)
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
return self.api_facade._segment_direct_text(normalized)
def _build_segment_request(
self,
normalized: NormalizedEngineRequest,
*,
request_id: str,
text: str,
) -> NormalizedEngineRequest:
return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text)
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_scheduler(normalized)
def _run_legacy_direct_tts_blocking(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return self.api_facade._run_legacy_direct_tts_blocking(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def _run_direct_tts_via_legacy_backend(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_legacy_backend(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
    """Await ``self.api_facade.run_direct_tts_async(req)`` and return its result."""
    pending = self.api_facade.run_direct_tts_async(req)
    return await pending
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
    """Run direct TTS through the API facade's synchronous entry point and return its result."""
    run = self.api_facade.run_direct_tts
    return run(req)
def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
    """Return the specs the API facade builds for ``request_items``."""
    build_specs = self.api_facade.build_scheduler_request_specs
    return build_specs(request_items)
def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
    """Return the single spec the API facade builds for ``payload``."""
    build_spec = self.api_facade.build_scheduler_submit_spec
    return build_spec(payload)
@staticmethod
def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
    """Hand ``states`` to ``EngineApiFacade.summarize_scheduler_states`` and return its result."""
    summarize = EngineApiFacade.summarize_scheduler_states
    return summarize(states)
@staticmethod
def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
    """Hand ``items`` to ``EngineApiFacade.summarize_scheduler_finished`` and return its result."""
    summarize = EngineApiFacade.summarize_scheduler_finished
    return summarize(items)
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
    """Await the API facade's scheduler debug run with the same positional arguments."""
    pending = self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
    return await pending
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
    """Await ``self.api_facade.run_scheduler_submit(payload)`` and return its result."""
    pending = self.api_facade.run_scheduler_submit(payload)
    return await pending
class EngineRuntimeDelegates:
    """Forwarding layer for runtime-state and management calls.

    Each method hands its arguments to the method of the same name on
    ``self.runtime_facade`` (policy admission goes to
    ``self.engine_policy_arbiter`` instead) and returns that result. The
    class defines no ``__init__``; the object these methods run on has to
    provide both attributes.
    """

    @staticmethod
    def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
        """Forward to ``EngineRuntimeFacade._safe_component_snapshot``."""
        take_snapshot = EngineRuntimeFacade._safe_component_snapshot
        return take_snapshot(component)

    def _build_stage_counters(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Forward to ``runtime_facade._build_stage_counters``."""
        facade = self.runtime_facade
        return facade._build_stage_counters(request_registry, worker_state)

    def _build_engine_policy_snapshot(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Forward to ``runtime_facade._build_engine_policy_snapshot``."""
        facade = self.runtime_facade
        return facade._build_engine_policy_snapshot(request_registry, worker_state)

    async def _wait_for_engine_policy_admission(
        self,
        *,
        request_id: str | None,
        timeout_sec: float | None,
    ) -> tuple[float, Dict[str, Any]]:
        """Await ``engine_policy_arbiter.wait_for_policy_admission`` with the same keywords.

        Unlike the other methods here this one bypasses ``runtime_facade``
        and talks to the arbiter directly.
        """
        admission = self.engine_policy_arbiter.wait_for_policy_admission(
            request_id=request_id,
            timeout_sec=timeout_sec,
        )
        return await admission

    def _build_stage_summary(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Forward to ``runtime_facade._build_stage_summary``."""
        facade = self.runtime_facade
        return facade._build_stage_summary(request_registry, worker_state)

    def get_scheduler_state(self) -> dict:
        """Return ``runtime_facade.get_scheduler_state()``."""
        facade = self.runtime_facade
        return facade.get_scheduler_state()

    def get_runtime_state(self) -> dict:
        """Return ``runtime_facade.get_runtime_state()``."""
        facade = self.runtime_facade
        return facade.get_runtime_state()

    def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
        """Call ``runtime_facade._wait_for_safe_reload`` with ``timeout_sec``; returns nothing."""
        facade = self.runtime_facade
        facade._wait_for_safe_reload(timeout_sec=timeout_sec)

    def set_refer_audio(self, refer_audio_path: str | None) -> dict:
        """Return ``runtime_facade.set_refer_audio(refer_audio_path)``."""
        facade = self.runtime_facade
        return facade.set_refer_audio(refer_audio_path)

    def set_gpt_weights(self, weights_path: str) -> dict:
        """Return ``runtime_facade.set_gpt_weights(weights_path)``."""
        facade = self.runtime_facade
        return facade.set_gpt_weights(weights_path)

    def set_sovits_weights(self, weights_path: str) -> dict:
        """Return ``runtime_facade.set_sovits_weights(weights_path)``."""
        facade = self.runtime_facade
        return facade.set_sovits_weights(weights_path)

    def handle_control(self, command: str) -> None:
        """Call ``runtime_facade.handle_control(command)``; returns nothing."""
        facade = self.runtime_facade
        facade.handle_control(command)

View File

@ -0,0 +1,198 @@
from __future__ import annotations
import os
import signal
import sys
from typing import Any, Dict, Optional
class EngineRuntimeFacade:
    """Runtime-state reporting and management operations for the engine.

    The facade stores only a reference to ``owner``. Every collaborator
    (TTS pipeline, registries, scheduler worker, locks, control callbacks)
    is read from the owner through the properties below each time it is
    needed.
    """

    def __init__(self, owner: Any) -> None:
        """Remember the owning engine object that backs every property."""
        self.owner = owner

    @property
    def tts(self):
        """The owner's ``tts`` attribute."""
        return self.owner.tts

    @property
    def reference_registry(self):
        """The owner's ``reference_registry`` attribute."""
        return self.owner.reference_registry

    @property
    def model_registry(self):
        """The owner's ``model_registry`` attribute."""
        return self.owner.model_registry

    @property
    def scheduler_worker(self):
        """The owner's ``scheduler_worker`` attribute."""
        return self.owner.scheduler_worker

    @property
    def engine_decode_runtime_owner(self):
        """The owner's ``engine_decode_runtime_owner`` attribute."""
        return self.owner.engine_decode_runtime_owner

    @property
    def engine_policy_arbiter(self):
        """The owner's ``engine_policy_arbiter`` attribute."""
        return self.owner.engine_policy_arbiter

    @property
    def management_lock(self):
        """The owner's ``management_lock`` attribute."""
        return self.owner.management_lock

    @property
    def direct_tts_lock(self):
        """The owner's ``direct_tts_lock`` attribute."""
        return self.owner.direct_tts_lock

    @property
    def control_callbacks(self):
        """The owner's ``control_callbacks`` attribute."""
        return self.owner.control_callbacks

    @staticmethod
    def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
        """Return ``dict(component.snapshot())``, or ``None`` when that is not possible.

        ``None`` is returned when ``component`` is ``None``, has no
        ``snapshot`` attribute, or when calling/converting the snapshot
        raises any ``Exception``.
        """
        if component is None or not hasattr(component, "snapshot"):
            return None
        try:
            return dict(component.snapshot())
        except Exception:
            # Deliberately best-effort: a failing snapshot is reported as
            # None instead of propagating into state reporting.
            return None

    def _build_stage_counters(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return ``engine_policy_arbiter.build_stage_counters`` for the given snapshots."""
        return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state)

    def _build_engine_policy_snapshot(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return ``engine_policy_arbiter.build_policy_snapshot`` for the given snapshots."""
        return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state)

    def _build_stage_summary(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble the per-stage summary dict.

        The result starts from the arbiter's stage counters and adds the
        owner's engine snapshots, the admission limits and batch/prepare
        state read from ``worker_state``, and best-effort snapshots of three
        optional TTS components (each ``None`` when unavailable).
        """
        counters = self._build_stage_counters(request_registry, worker_state)
        bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None))
        ref_semantic_worker_state = self._safe_component_snapshot(
            getattr(self.tts, "prepare_ref_semantic_batch_worker", None)
        )
        text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None))
        return {
            **counters,
            "engine_drained": bool(self.owner._is_engine_drained()),
            "admission_config": {
                "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)),
                "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)),
            },
            "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state),
            "engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(),
            "engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(),
            "engine_job_registry": self.owner._snapshot_engine_job_registry(),
            "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
            "engine_prepare_state": self.owner._snapshot_engine_prepare_state(),
            "engine_finalize_state": self.owner._snapshot_engine_finalize_state(),
            "engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(),
            "active_batch": dict(worker_state.get("active_batch") or {}),
            "prepare_state": dict(worker_state.get("prepare_state") or {}),
            "bert_batch_worker_state": bert_worker_state,
            "ref_semantic_worker_state": ref_semantic_worker_state,
            "text_preprocessor_state": text_preprocessor_state,
        }

    def get_scheduler_state(self) -> dict:
        """Return the scheduler worker's ``snapshot()``."""
        return self.scheduler_worker.snapshot()

    def get_runtime_state(self) -> dict:
        """Collect the full runtime state report.

        The report combines the default reference, the model registry
        generations/paths, the scheduler snapshot (as ``worker_state``), the
        owner's engine snapshots, the request registry and a nested
        ``stage_summary``.
        """
        model_state = self.model_registry.snapshot()
        default_ref = self.reference_registry.get_default()
        scheduler_state = self.get_scheduler_state()
        request_registry = self.owner._snapshot_request_registry()
        engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state)
        engine_arbiter_state = self.owner._snapshot_engine_arbiter_state()
        engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state()
        engine_job_registry = self.owner._snapshot_engine_job_registry()
        engine_prepare_state = self.owner._snapshot_engine_prepare_state()
        engine_finalize_state = self.owner._snapshot_engine_finalize_state()
        engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state()
        engine_drained = self.owner._is_engine_drained()
        return {
            "message": "success",
            "default_reference": {
                "ref_audio_path": default_ref.ref_audio_path,
                "updated_at": default_ref.updated_at,
            },
            "model_registry": {
                "generation": model_state.generation,
                "t2s_generation": model_state.t2s_generation,
                "vits_generation": model_state.vits_generation,
                "t2s_weights_path": model_state.t2s_weights_path,
                "vits_weights_path": model_state.vits_weights_path,
                "updated_at": model_state.updated_at,
            },
            "worker_state": scheduler_state,
            "engine_policy": engine_policy,
            "engine_arbiter_state": engine_arbiter_state,
            "engine_decode_runtime_state": engine_decode_runtime_state,
            "engine_job_registry": engine_job_registry,
            "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
            "engine_prepare_state": engine_prepare_state,
            "engine_finalize_state": engine_finalize_state,
            "engine_dispatcher_state": engine_dispatcher_state,
            "engine_drained": bool(engine_drained),
            "request_registry": request_registry,
            # NOTE(review): _build_stage_summary queries the owner's snapshot
            # methods again, so the nested values are a second reading and
            # are not guaranteed to match the top-level ones above.
            "stage_summary": self._build_stage_summary(request_registry, scheduler_state),
        }

    def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
        """Block until the scheduler worker reports idle.

        Raises:
            TimeoutError: ``scheduler_worker.wait_until_idle`` returned a
                falsy value for the given ``timeout_sec``.
        """
        if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec):
            raise TimeoutError("scheduler worker did not drain before model reload")

    def set_refer_audio(self, refer_audio_path: str | None) -> dict:
        """Set or clear the default reference audio.

        ``None`` or ``""`` clears the registry entry (no lock is taken on
        that path). Any other value must name an existing path; it is then
        applied to the TTS pipeline and recorded in the registry while
        holding ``management_lock`` and ``direct_tts_lock``.

        Raises:
            FileNotFoundError: the path does not exist.
        """
        if refer_audio_path in [None, ""]:
            state = self.reference_registry.clear()
            return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
        if not os.path.exists(str(refer_audio_path)):
            raise FileNotFoundError(f"{refer_audio_path} not exists")
        with self.management_lock:
            with self.direct_tts_lock:
                self.tts.set_ref_audio(str(refer_audio_path))
            state = self.reference_registry.set_default(str(refer_audio_path))
        return {"message": "success", "default_ref_audio_path": state.ref_audio_path}

    def set_gpt_weights(self, weights_path: str) -> dict:
        """Reload the T2S (GPT) weights after the scheduler has drained.

        Raises:
            ValueError: ``weights_path`` is ``""`` or ``None``.
            TimeoutError: the scheduler worker did not become idle in time.
        """
        if weights_path in ["", None]:
            raise ValueError("gpt weight path is required")
        with self.management_lock:
            self._wait_for_safe_reload()
            with self.direct_tts_lock:
                self.tts.init_t2s_weights(weights_path)
                self.tts.refresh_runtime_components()
            state = self.model_registry.mark_t2s_reload(str(weights_path))
        return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation}

    def set_sovits_weights(self, weights_path: str) -> dict:
        """Reload the VITS (SoVITS) weights after the scheduler has drained.

        Raises:
            ValueError: ``weights_path`` is ``""`` or ``None``.
            TimeoutError: the scheduler worker did not become idle in time.
        """
        if weights_path in ["", None]:
            raise ValueError("sovits weight path is required")
        with self.management_lock:
            self._wait_for_safe_reload()
            with self.direct_tts_lock:
                self.tts.init_vits_weights(weights_path)
                self.tts.refresh_runtime_components()
            state = self.model_registry.mark_vits_reload(str(weights_path))
        return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation}

    def handle_control(self, command: str) -> None:
        """Execute a process-level control command.

        ``"restart"`` and ``"exit"`` call the matching callback from
        ``control_callbacks`` when one is set. Without a callback, restart
        re-executes the interpreter with the current ``sys.argv`` and exit
        sends ``SIGTERM`` to this process.

        Raises:
            ValueError: ``command`` is neither ``"restart"`` nor ``"exit"``.
        """
        callbacks = self.control_callbacks
        if command == "restart":
            if callbacks.restart is None:
                os.execl(sys.executable, sys.executable, *sys.argv)
                # execl does not return on success; returning explicitly
                # mirrors the exit branch and guarantees the missing
                # callback is never invoked should execl ever come back.
                return
            callbacks.restart()
            return
        if command == "exit":
            if callbacks.exit is None:
                os.kill(os.getpid(), signal.SIGTERM)
                return
            callbacks.exit()
            return
        raise ValueError(f"unsupported command: {command}")

View File

@ -0,0 +1,420 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineDecodeRuntimeOwner,
EngineDispatchTask,
EngineGpuPrepareTask,
EngineStatus,
EngineTaskQueueOwner,
SchedulerFinalizeTask,
SchedulerPendingJob,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineStageCoordinator:
    """Runs the engine's per-stage work items and the arbiter loop around them.

    The coordinator is wired with the prepare / finalize / dispatch task
    queues, the decode runtime owner, the scheduler worker and a set of
    bookkeeping callbacks. Each ``run_engine_*_once`` method performs at
    most one unit of work for its stage and reports whether it did
    anything; ``run_engine_arbiter_loop`` keeps asking the bound arbiter
    which stage to run next.
    """

    def __init__(
        self,
        *,
        tts: TTS,
        scheduler_worker: UnifiedSchedulerWorker,
        prepare_queue_owner: EngineTaskQueueOwner,
        finalize_queue_owner: EngineTaskQueueOwner,
        dispatch_queue_owner: EngineTaskQueueOwner,
        decode_runtime_owner: EngineDecodeRuntimeOwner,
        update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
        merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
        fail_request_state: Callable[[str, str], None],
        get_engine_job: Callable[[str], SchedulerPendingJob | None],
        register_engine_job: Callable[[SchedulerPendingJob], None],
        fail_engine_jobs: Callable[[List[str], str], None],
        complete_engine_job: Callable[..., None],
        add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
        add_engine_merge_time: Callable[[List[str], float], None],
        add_engine_decode_time: Callable[[List[str], float], None],
        enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
        snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
        snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
    ) -> None:
        """Store every collaborator and callback; arbiter hooks start unbound."""
        # Core collaborators.
        self.tts = tts
        self.scheduler_worker = scheduler_worker
        self.decode_runtime_owner = decode_runtime_owner

        # One queue owner per stage.
        self.prepare_queue_owner = prepare_queue_owner
        self.finalize_queue_owner = finalize_queue_owner
        self.dispatch_queue_owner = dispatch_queue_owner

        # Request-state bookkeeping callbacks.
        self.update_request_state = update_request_state
        self.merge_request_state_profile = merge_request_state_profile
        self.fail_request_state = fail_request_state

        # Engine job registry callbacks.
        self.get_engine_job = get_engine_job
        self.register_engine_job = register_engine_job
        self.fail_engine_jobs = fail_engine_jobs
        self.complete_engine_job = complete_engine_job

        # Timing accumulation and finished-item hand-off callbacks.
        self.add_engine_prefill_time = add_engine_prefill_time
        self.add_engine_merge_time = add_engine_merge_time
        self.add_engine_decode_time = add_engine_decode_time
        self.enqueue_engine_finished_items = enqueue_engine_finished_items

        # Snapshot providers.
        self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
        self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state

        # Arbiter hooks; filled in later by bind_arbiter().
        self._notify_arbiter: Callable[[], None] | None = None
        self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None
        self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None
        self._wait_arbiter: Callable[[], None] | None = None

    def bind_arbiter(
        self,
        *,
        notify_arbiter: Callable[[], None],
        select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
        mark_arbiter_tick: Callable[[str, str, bool], None],
        wait_arbiter: Callable[[], None],
    ) -> None:
        """Install the four arbiter hooks used by the notify helper and the arbiter loop."""
        self._wait_arbiter = wait_arbiter
        self._mark_arbiter_tick = mark_arbiter_tick
        self._select_stage = select_stage
        self._notify_arbiter = notify_arbiter

    def notify_arbiter(self) -> None:
        """Invoke the bound notify hook; a no-op while none is bound."""
        hook = self._notify_arbiter
        if hook is None:
            return
        hook()

    @staticmethod
    def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
        """Set ``error`` on ``future`` unless the future is already done."""
        if not future.done():
            future.set_exception(error)

    def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
        """Schedule ``error`` onto the dispatch task's future via its event loop.

        Does nothing when the task carries no loop or no future. A
        ``RuntimeError`` from ``call_soon_threadsafe`` is swallowed.
        """
        if task.done_loop is None or task.done_future is None:
            return
        resolver = self._resolve_dispatch_error_future
        try:
            task.done_loop.call_soon_threadsafe(resolver, task.done_future, error)
        except RuntimeError:
            # Best effort: presumably the loop is already closed and nobody
            # is left to receive the error.
            pass

    @staticmethod
    def _resolve_prepare_future(
        future: asyncio.Future,
        payload: tuple[T2SRequestState, float, float],
    ) -> None:
        """Set ``payload`` as the result of ``future`` unless it is already done."""
        if not future.done():
            future.set_result(payload)

    def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
        """Schedule ``error`` onto the prepare task's future via its event loop.

        Uses the same resolver as dispatch errors. Does nothing when the
        task carries no loop or no future; a ``RuntimeError`` from
        ``call_soon_threadsafe`` is swallowed.
        """
        if task.done_loop is None or task.done_future is None:
            return
        resolver = self._resolve_dispatch_error_future
        try:
            task.done_loop.call_soon_threadsafe(resolver, task.done_future, error)
        except RuntimeError:
            pass

    def _notify_prepare_result(
        self,
        task: EngineGpuPrepareTask,
        payload: tuple[T2SRequestState, float, float],
    ) -> None:
        """Schedule ``payload`` as the prepare task's future result via its event loop.

        Does nothing when the task carries no loop or no future; a
        ``RuntimeError`` from ``call_soon_threadsafe`` is swallowed.
        """
        if task.done_loop is None or task.done_future is None:
            return
        resolver = self._resolve_prepare_future
        try:
            task.done_loop.call_soon_threadsafe(resolver, task.done_future, payload)
        except RuntimeError:
            pass

    async def prepare_state_via_engine_gpu_queue(
        self,
        *,
        spec,
        prepare_submit_at: float,
        engine_request_id: str | None,
    ) -> tuple[T2SRequestState, float, float]:
        """Run the CPU prepare stage, then queue the GPU stage and await its result.

        After the CPU stage, the request (when ``engine_request_id`` is
        neither ``None`` nor ``""``) is moved to ``GPU_PREPARING`` with the
        CPU queue/run timings. A GPU prepare task bound to the running loop
        is then enqueued, the arbiter is notified, and the coroutine waits
        for the task's future.

        Returns:
            ``(state, prepare_exec_started_at, prepare_exec_finished_at)``
            as delivered through the future.
        """
        cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
        if engine_request_id not in [None, ""]:
            prompt_cpu = cpu_stage.prompt_cpu_profiled
            target_cpu = cpu_stage.target_cpu_profiled
            cpu_timings = {
                "prompt_text_cpu_queue_ms": float(prompt_cpu.queue_ms),
                "prompt_text_cpu_run_ms": float(prompt_cpu.run_ms),
                "text_cpu_queue_ms": float(target_cpu.queue_ms),
                "text_cpu_run_ms": float(target_cpu.run_ms),
            }
            self.update_request_state(str(engine_request_id), EngineStatus.GPU_PREPARING, cpu_timings)

        running_loop = asyncio.get_running_loop()
        result_future = running_loop.create_future()
        gpu_task = EngineGpuPrepareTask(
            request_id=spec.request_id,
            cpu_stage=cpu_stage,
            done_loop=running_loop,
            done_future=result_future,
            engine_request_id=engine_request_id or spec.request_id,
            enqueue_time=time.perf_counter(),
        )
        self.prepare_queue_owner.enqueue(gpu_task)
        self.notify_arbiter()

        state, prepare_exec_started_at, prepare_exec_finished_at = await result_future
        return state, prepare_exec_started_at, prepare_exec_finished_at

    def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
        """Mark finished tasks as ready for finalize and queue them.

        Tasks whose job is not registered get no state update but are still
        enqueued with the rest. An empty list is ignored entirely.
        """
        if not tasks:
            return
        for finished in tasks:
            job = self.get_engine_job(finished.request_id)
            if job is None:
                continue
            self.update_request_state(
                job.engine_request_id,
                EngineStatus.READY_FOR_FINALIZE,
                {
                    "finish_reason": finished.item.finish_reason,
                    "semantic_len": int(finished.item.semantic_tokens.shape[0]),
                    "finish_idx": int(finished.item.finish_idx),
                },
            )
        self.finalize_queue_owner.enqueue_many(tasks)
        self.notify_arbiter()

    def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
        """Take the next finalize batch using the worker's finalize batch policy.

        Missing policy keys default to mode ``"async"``, one item per batch
        and zero wait.
        """
        policy = self.scheduler_worker.get_finalize_batch_policy()
        mode = str(policy.get("finalize_mode", "async"))
        max_items = int(policy.get("finalize_batch_max_items", 1))
        wait_s = float(policy.get("finalize_batch_wait_s", 0.0))
        vocoder_enabled = bool(self.tts.configs.use_vocoder)
        return self.finalize_queue_owner.take_finalize_batch(
            finalize_mode=mode,
            batch_max_items=max_items,
            batch_wait_s=wait_s,
            use_vocoder=vocoder_enabled,
        )

    async def enqueue_prepared_state_for_dispatch(
        self,
        *,
        state: T2SRequestState,
        speed_factor: float,
        sample_steps: int,
        media_type: str,
        prepare_wall_ms: float,
        prepare_profile_total_ms: float,
        done_loop: asyncio.AbstractEventLoop | None,
        done_future: asyncio.Future | None,
        engine_request_id: str | None,
        timeout_sec: float | None,
    ) -> EngineDispatchTask:
        """Wrap a prepared state in a dispatch task, queue it and notify the arbiter.

        After enqueueing, the dispatch queue's ``waiting_count`` is recorded
        in the request profile as ``engine_dispatch_queue_depth_on_enqueue``.

        Returns:
            The dispatch task that was enqueued.
        """
        dispatch_task = EngineDispatchTask(
            request_id=state.request_id,
            state=state,
            speed_factor=float(speed_factor),
            sample_steps=int(sample_steps),
            media_type=media_type,
            prepare_wall_ms=float(prepare_wall_ms),
            prepare_profile_total_ms=float(prepare_profile_total_ms),
            done_loop=done_loop,
            done_future=done_future,
            engine_request_id=engine_request_id or state.request_id,
            timeout_sec=timeout_sec,
            enqueue_time=time.perf_counter(),
        )
        self.dispatch_queue_owner.enqueue(dispatch_task)
        self.notify_arbiter()

        profile_key = dispatch_task.engine_request_id or dispatch_task.request_id
        queue_depth = int(self.snapshot_engine_dispatch_state()["waiting_count"])
        self.merge_request_state_profile(
            profile_key,
            {"engine_dispatch_queue_depth_on_enqueue": queue_depth},
        )
        return dispatch_task

    def peek_queue_age_ms(self, queue_name: str) -> float:
        """Return the age of the oldest entry for the named queue.

        ``"prepare"``, ``"finalize"`` and ``"decode_runtime_pending"`` are
        recognised; any other name is answered from the dispatch queue.
        Finalize tasks are timestamped under ``enqueued_time`` while the
        other queues use ``enqueue_time``.
        """
        if queue_name == "decode_runtime_pending":
            return self.decode_runtime_owner.pending_age_ms()
        if queue_name == "finalize":
            queue_owner, timestamp_field = self.finalize_queue_owner, "enqueued_time"
        elif queue_name == "prepare":
            queue_owner, timestamp_field = self.prepare_queue_owner, "enqueue_time"
        else:
            queue_owner, timestamp_field = self.dispatch_queue_owner, "enqueue_time"
        return queue_owner.peek_oldest_age_ms(timestamp_field)

    def has_pending_work(self) -> bool:
        """Report whether any stage currently has something to do.

        With engine decode control enabled, pending decode jobs or a
        positive ``active_request_count`` count as work; otherwise only the
        prepare, finalize and dispatch queues are consulted.
        """
        worker = self.scheduler_worker
        if worker.is_engine_decode_control_enabled() and self.decode_runtime_owner.has_pending_jobs():
            return True
        if worker.is_engine_decode_control_enabled():
            active_count = self.snapshot_engine_decode_runtime_state().get("active_request_count", 0)
            if active_count > 0:
                return True
        if self.prepare_queue_owner.has_items() or self.finalize_queue_owner.has_items():
            return True
        return self.dispatch_queue_owner.has_items()

    def run_engine_prepare_once(self) -> bool:
        """Execute the GPU prepare stage for the oldest queued prepare task.

        Returns:
            ``False`` when the prepare queue is empty, ``True`` once a task
            has been taken, whether it succeeded or failed. On failure the
            request is marked failed and the error is delivered to the
            task's future.
        """
        task = self.prepare_queue_owner.pop_left()
        if task is None:
            return False
        waited_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
        try:
            gpu_stage = self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
            state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(gpu_stage)
            state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(waited_ms)
            if task.engine_request_id not in [None, ""]:
                self.merge_request_state_profile(
                    str(task.engine_request_id),
                    {"engine_gpu_prepare_queue_wait_ms": float(waited_ms)},
                )
            self.prepare_queue_owner.mark_completed(1)
            self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
        except Exception as exc:
            # NOTE(review): mark_completed is not called on this path, unlike
            # the finalize stage which always marks completion - confirm
            # that this is intended.
            failure = str(exc)
            task.error = failure
            self.fail_request_state(task.engine_request_id or task.request_id, failure)
            self._notify_prepare_error(task, exc)
        return True

    def run_engine_finalize_once(self) -> bool:
        """Synthesize audio for one batch of finalize tasks.

        Returns:
            ``False`` when no batch is available or none of the batch's
            tasks has a registered job; ``True`` otherwise, including when
            synthesis raised and the batch's jobs were failed. Once a batch
            has been taken, finalize execution is always ended and the
            batch marked completed.
        """
        tasks = self.take_engine_finalize_batch_nonblocking()
        if not tasks:
            return False
        batch_size = len(tasks)
        self.scheduler_worker.begin_finalize_execution(batch_size)
        try:
            jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
            for task in tasks:
                registered = self.get_engine_job(task.request_id)
                if registered is not None:
                    jobs_and_items.append((registered, task.item))
            if not jobs_and_items:
                return False

            picked_up_at = time.perf_counter()
            for task in tasks:
                registered = self.get_engine_job(task.request_id)
                if registered is None:
                    continue
                registered.finalize_wait_ms += max(0.0, (picked_up_at - task.enqueued_time) * 1000.0)

            for job, item in jobs_and_items:
                self.update_request_state(
                    job.engine_request_id,
                    EngineStatus.FINALIZING,
                    {
                        "finish_reason": item.finish_reason,
                        "semantic_len": int(item.semantic_tokens.shape[0]),
                    },
                )

            synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
            batch_synth_ms = float(synth_ms)
            for job, _ in jobs_and_items:
                # Every job in the batch is charged the full batch synth time.
                job.synth_ms += batch_synth_ms
            # NOTE(review): zip stops at the shorter sequence, so a short
            # batch_results would leave trailing jobs neither completed nor
            # failed - confirm synthesize_finalize_jobs returns one result
            # per job.
            for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
                self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
        except Exception as exc:
            self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
        finally:
            self.scheduler_worker.end_finalize_execution(batch_size)
            self.finalize_queue_owner.mark_completed(batch_size, notify=True)
        return True

    def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
        """Submit the oldest queued dispatch task to the scheduler worker.

        Returns:
            ``False`` when the policy snapshot disallows dispatch or the
            queue is empty; ``True`` once a task has been taken, whether the
            submit succeeded or failed. On failure the request is marked
            failed and the error is delivered to the task's future.
        """
        if not bool(policy_snapshot.get("allowed", True)):
            return False
        task = self.dispatch_queue_owner.pop_left()
        if task is None:
            return False

        waited_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
        # The same wait is recorded under both the policy and dispatch names.
        task.engine_policy_wait_ms = float(waited_ms)
        task.engine_dispatch_wait_ms = float(waited_ms)
        task.engine_policy_snapshot = dict(policy_snapshot)
        try:
            worker = self.scheduler_worker
            worker_job = worker.submit(
                state=task.state,
                speed_factor=task.speed_factor,
                sample_steps=task.sample_steps,
                media_type=task.media_type,
                prepare_wall_ms=task.prepare_wall_ms,
                prepare_profile_total_ms=task.prepare_profile_total_ms,
                done_loop=task.done_loop,
                done_future=task.done_future,
                engine_request_id=task.engine_request_id,
                timeout_sec=task.timeout_sec,
                skip_capacity_wait=True,
                admission_wait_ms_override=0.0,
                admission_snapshot_override=dict(worker_state),
                engine_policy_wait_ms=task.engine_policy_wait_ms,
                engine_dispatch_wait_ms=task.engine_dispatch_wait_ms,
                enqueue_pending=not worker.is_engine_decode_control_enabled(),
            )
            task.worker_job = worker_job
            self.register_engine_job(worker_job)
            if worker.is_engine_decode_control_enabled():
                # Under engine decode control the job is queued on the decode
                # runtime owner here (submit was told not to enqueue it).
                self.decode_runtime_owner.enqueue_pending_job(worker_job)
                self.notify_arbiter()
            self.dispatch_queue_owner.mark_completed(1)
        except Exception as exc:
            # NOTE(review): as in the prepare stage, mark_completed is
            # skipped when the submit fails - confirm that this is intended.
            failure = str(exc)
            task.error = failure
            self.fail_request_state(task.engine_request_id or task.request_id, failure)
            self._notify_dispatch_error(task, exc)
        return True

    def run_engine_decode_runtime_once(self) -> bool:
        """Run one decode cycle under engine decode control.

        Pending jobs are taken (waiting for a batch only when no request is
        active), one decode cycle is executed, and the prefill and decode
        phase results are applied: failed request ids are failed, otherwise
        timings are accumulated and finished items handed on. The active
        batch is then replaced with the cycle's result.

        Returns:
            ``False`` when engine decode control is disabled; otherwise the
            truthiness of the cycle's ``executed`` flag.
        """
        if not self.scheduler_worker.is_engine_decode_control_enabled():
            return False

        runtime_state = self.snapshot_engine_decode_runtime_state()
        nothing_active = int(runtime_state.get("active_request_count", 0)) <= 0
        pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch=nothing_active)
        result = self.scheduler_worker.execute_decode_cycle(
            pending_jobs=pending_jobs,
            active_batch=self.decode_runtime_owner.get_active_batch(),
            external_bookkeeping=True,
        )

        prefill = dict(result.get("prefill_phase") or {})
        if prefill.get("error"):
            failed_ids = list(prefill.get("error_request_ids") or [])
            self.fail_engine_jobs(failed_ids, str(prefill.get("error")))
        else:
            prefill_jobs = list(prefill.get("pending_jobs") or [])
            self.add_engine_prefill_time(prefill_jobs, float(prefill.get("prefill_elapsed_s", 0.0)))
            merged_batch = result.get("active_batch")
            merged_ids = [] if merged_batch is None else list(merged_batch.request_ids)
            self.add_engine_merge_time(merged_ids, float(prefill.get("merge_elapsed_s", 0.0)))
            self.enqueue_engine_finished_items(list(prefill.get("finished_items") or []))

        decode = dict(result.get("decode_phase") or {})
        if decode.get("error"):
            failed_ids = list(decode.get("error_request_ids") or [])
            self.fail_engine_jobs(failed_ids, str(decode.get("error")))
        else:
            decoded_ids = list(decode.get("request_ids") or [])
            self.add_engine_decode_time(decoded_ids, float(decode.get("decode_elapsed_s", 0.0)))
            self.enqueue_engine_finished_items(list(decode.get("finished_items") or []))

        self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
        executed = result.get("executed", False)
        if executed:
            self.decode_runtime_owner.refresh_state("engine_decode_cycle")
        return bool(executed)

    def run_engine_arbiter_loop(self) -> None:
        """Run the stage-selection loop forever.

        Each iteration either idles (no pending work, or the selected stage
        executed nothing) by recording an ``"idle"`` tick and waiting on
        the arbiter, or records a tick for the stage that ran. The loop has
        no exit condition of its own.

        Raises:
            RuntimeError: the select, tick or wait hook has not been bound.
        """
        if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None:
            raise RuntimeError("arbiter callbacks are not bound")
        while True:
            if not self.has_pending_work():
                self._mark_arbiter_tick("idle", "no_pending_work", True)
                self._wait_arbiter()
                continue

            stage, reason, policy_snapshot, worker_state = self._select_stage()
            policy_allowed = bool(policy_snapshot.get("allowed", True))

            # Unknown stage names run nothing and are treated as "not ready".
            executed = False
            if stage == "prepare":
                executed = self.run_engine_prepare_once()
            elif stage == "finalize":
                executed = self.run_engine_finalize_once()
            elif stage == "decode_dispatch":
                executed = self.run_engine_dispatch_once(policy_snapshot, worker_state)
            elif stage == "decode_runtime":
                executed = self.run_engine_decode_runtime_once()

            if executed:
                self._mark_arbiter_tick(stage, reason, policy_allowed)
            else:
                self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed)
                self._wait_arbiter()

File diff suppressed because it is too large Load Diff