mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-18 16:28:11 +08:00
Add unified engine components and API for enhanced TTS processing
Introduce multiple new modules including unified_engine_api, unified_engine_audio, unified_engine_bridge, unified_engine_builder, unified_engine_components, unified_engine_delegates, and unified_engine_runtime. These additions provide a comprehensive framework for managing TTS requests, audio packing, and engine state management, significantly improving the architecture and maintainability of the TTS system. The new structure supports asynchronous operations and enhances overall performance through better request handling and processing capabilities.
This commit is contained in:
parent
06d6b67f73
commit
d1ec7d9e54
File diff suppressed because it is too large
Load Diff
1399
GPT_SoVITS/TTS_infer_pack/unified_engine_api.py
Normal file
1399
GPT_SoVITS/TTS_infer_pack/unified_engine_api.py
Normal file
File diff suppressed because it is too large
Load Diff
106
GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py
Normal file
106
GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py
Normal file
@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import threading
|
||||
import wave
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
|
||||
def set_scheduler_seed(seed: int):
|
||||
if seed in ["", None]:
|
||||
return
|
||||
seed = int(seed)
|
||||
if seed < 0:
|
||||
return
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
def handle_pack_ogg():
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
stack_size = 4096 * 4096
|
||||
try:
|
||||
threading.stack_size(stack_size)
|
||||
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
|
||||
pack_ogg_thread.start()
|
||||
pack_ogg_thread.join()
|
||||
except (RuntimeError, ValueError):
|
||||
handle_pack_ogg()
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer.write(data.tobytes())
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer = BytesIO()
|
||||
sf.write(io_buffer, data, rate, format="wav")
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ar",
|
||||
str(rate),
|
||||
"-ac",
|
||||
"1",
|
||||
"-i",
|
||||
"pipe:0",
|
||||
"-c:a",
|
||||
"aac",
|
||||
"-b:a",
|
||||
"192k",
|
||||
"-vn",
|
||||
"-f",
|
||||
"adts",
|
||||
"pipe:1",
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
|
||||
wav_buf = BytesIO()
|
||||
with wave.open(wav_buf, "wb") as vfout:
|
||||
vfout.setnchannels(channels)
|
||||
vfout.setsampwidth(sample_width)
|
||||
vfout.setframerate(sample_rate)
|
||||
vfout.writeframes(frame_input)
|
||||
wav_buf.seek(0)
|
||||
return wav_buf.read()
|
||||
|
||||
|
||||
310
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py
Normal file
310
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py
Normal file
@ -0,0 +1,310 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineBridgeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
@property
|
||||
def request_registry(self):
|
||||
return self.owner.request_registry
|
||||
|
||||
@property
|
||||
def engine_prepare_queue_owner(self):
|
||||
return self.owner.engine_prepare_queue_owner
|
||||
|
||||
@property
|
||||
def engine_finalize_queue_owner(self):
|
||||
return self.owner.engine_finalize_queue_owner
|
||||
|
||||
@property
|
||||
def engine_dispatch_queue_owner(self):
|
||||
return self.owner.engine_dispatch_queue_owner
|
||||
|
||||
@property
|
||||
def engine_decode_runtime_owner(self):
|
||||
return self.owner.engine_decode_runtime_owner
|
||||
|
||||
@property
|
||||
def engine_job_registry(self):
|
||||
return self.owner.engine_job_registry
|
||||
|
||||
@property
|
||||
def scheduler_worker(self):
|
||||
return self.owner.scheduler_worker
|
||||
|
||||
@property
|
||||
def engine_stage_coordinator(self):
|
||||
return self.owner.engine_stage_coordinator
|
||||
|
||||
@property
|
||||
def engine_policy_arbiter(self):
|
||||
return self.owner.engine_policy_arbiter
|
||||
|
||||
def _register_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
api_mode: str,
|
||||
backend: str,
|
||||
media_type: str,
|
||||
response_streaming: bool,
|
||||
deadline_ts: float | None = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
) -> EngineRequestState:
|
||||
return self.request_registry.register(
|
||||
request_id=request_id,
|
||||
api_mode=api_mode,
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=response_streaming,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def _update_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
status: str,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.request_registry.update(request_id, status, extra)
|
||||
|
||||
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.request_registry.merge_profile(request_id, extra)
|
||||
|
||||
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.request_registry.complete(request_id, extra)
|
||||
|
||||
def _fail_request_state(self, request_id: str, error: str) -> None:
|
||||
self.request_registry.fail(request_id, error)
|
||||
|
||||
def _snapshot_request_registry(self) -> Dict[str, Any]:
|
||||
return self.request_registry.snapshot()
|
||||
|
||||
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
|
||||
return self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
|
||||
|
||||
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
|
||||
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
|
||||
|
||||
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
|
||||
return self.engine_dispatch_queue_owner.snapshot(
|
||||
max_request_ids=16,
|
||||
extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})},
|
||||
)
|
||||
|
||||
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.engine_job_registry.register(job, keep_job=True)
|
||||
|
||||
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.engine_job_registry.get(request_id)
|
||||
|
||||
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.engine_job_registry.pop(request_id)
|
||||
|
||||
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
|
||||
return self.engine_job_registry.snapshot(max_request_ids=32)
|
||||
|
||||
def _is_engine_drained(self) -> bool:
|
||||
prepare_empty = self.engine_prepare_queue_owner.is_drained()
|
||||
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
|
||||
finalize_empty = self.engine_finalize_queue_owner.is_drained()
|
||||
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
|
||||
job_empty = self.engine_job_registry.is_empty()
|
||||
worker_state = self.scheduler_worker.snapshot()
|
||||
return bool(
|
||||
prepare_empty
|
||||
and dispatch_empty
|
||||
and finalize_empty
|
||||
and decode_pending_empty
|
||||
and job_empty
|
||||
and self.engine_decode_runtime_owner.get_active_batch() is None
|
||||
and int(worker_state.get("prepare_inflight", 0)) <= 0
|
||||
and int(worker_state.get("finalize_inflight", 0)) <= 0
|
||||
and int(worker_state.get("finalize_pending", 0)) <= 0
|
||||
)
|
||||
|
||||
def _record_engine_job_done(self, request_id: str) -> None:
|
||||
self.engine_job_registry.mark_finished_and_remove(request_id)
|
||||
self.scheduler_worker.record_external_job_done(request_id)
|
||||
|
||||
def _complete_engine_job(
|
||||
self,
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
*,
|
||||
sample_rate: int,
|
||||
audio_data: np.ndarray,
|
||||
) -> None:
|
||||
completion_bridge = self.scheduler_worker.completion_bridge
|
||||
completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
completion_bridge.complete_job(
|
||||
job,
|
||||
runtime_request_id=job.engine_request_id,
|
||||
runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate),
|
||||
on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid),
|
||||
)
|
||||
|
||||
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
|
||||
if not request_ids:
|
||||
return
|
||||
completion_bridge = self.scheduler_worker.completion_bridge
|
||||
for request_id in request_ids:
|
||||
job = self._get_engine_job(request_id)
|
||||
if job is None:
|
||||
continue
|
||||
completion_bridge.fail_job(
|
||||
job,
|
||||
error=error,
|
||||
on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid),
|
||||
)
|
||||
|
||||
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
for job in jobs:
|
||||
job.prefill_ms += delta_ms
|
||||
|
||||
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
for request_id in request_ids:
|
||||
job = self._get_engine_job(request_id)
|
||||
if job is not None:
|
||||
job.merge_ms += delta_ms
|
||||
|
||||
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
activate_request_ids: List[str] = []
|
||||
for request_id in request_ids:
|
||||
job = self._get_engine_job(request_id)
|
||||
if job is None:
|
||||
continue
|
||||
if job.decode_steps == 0:
|
||||
activate_request_ids.append(job.engine_request_id)
|
||||
job.decode_ms += delta_ms
|
||||
job.decode_steps += 1
|
||||
for engine_request_id in activate_request_ids:
|
||||
self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
|
||||
|
||||
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
|
||||
if not items:
|
||||
return
|
||||
enqueued_at = time.perf_counter()
|
||||
tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items]
|
||||
self._enqueue_worker_finished_for_finalize(tasks)
|
||||
|
||||
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
|
||||
return self.engine_decode_runtime_owner.snapshot_pending_queue_state()
|
||||
|
||||
@staticmethod
|
||||
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
|
||||
return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)
|
||||
|
||||
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
|
||||
self.engine_decode_runtime_owner.refresh_state(last_event)
|
||||
|
||||
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
|
||||
if not snapshot:
|
||||
return
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
return
|
||||
self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot)
|
||||
|
||||
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
|
||||
return self.engine_decode_runtime_owner.snapshot_state()
|
||||
|
||||
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
|
||||
return self.engine_policy_arbiter.snapshot_state()
|
||||
|
||||
def _notify_engine_arbiter(self) -> None:
|
||||
self.engine_policy_arbiter.notify()
|
||||
|
||||
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.engine_stage_coordinator.decode_runtime_owner.enqueue_pending_job(job)
|
||||
self._notify_engine_arbiter()
|
||||
|
||||
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
return self.engine_stage_coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch)
|
||||
|
||||
def _peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
return self.engine_stage_coordinator.peek_queue_age_ms(queue_name)
|
||||
|
||||
def _engine_has_pending_work(self) -> bool:
|
||||
return self.engine_stage_coordinator.has_pending_work()
|
||||
|
||||
async def _prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=prepare_submit_at,
|
||||
engine_request_id=engine_request_id,
|
||||
)
|
||||
|
||||
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)
|
||||
|
||||
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking()
|
||||
|
||||
async def _enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
|
||||
self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
|
||||
|
||||
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage()
|
||||
self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot)
|
||||
return stage, reason, policy_snapshot, worker_state
|
||||
|
||||
def _run_engine_prepare_once(self) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_prepare_once()
|
||||
|
||||
def _run_engine_finalize_once(self) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_finalize_once()
|
||||
|
||||
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
|
||||
def _run_engine_decode_runtime_once(self) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_decode_runtime_once()
|
||||
|
||||
def _run_engine_arbiter_loop(self) -> None:
|
||||
self.engine_stage_coordinator.run_engine_arbiter_loop()
|
||||
179
GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py
Normal file
179
GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py
Normal file
@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
EngineArbiterConfig,
|
||||
EngineDecodeRuntimeOwner,
|
||||
EnginePolicyArbiterController,
|
||||
EnginePolicyConfig,
|
||||
EngineRequestRegistry,
|
||||
EngineTaskQueueOwner,
|
||||
ModelRegistry,
|
||||
ReferenceRegistry,
|
||||
RuntimeStateCallbacks,
|
||||
SchedulerJobRegistry,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineCompositionBuilder:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
|
||||
self._init_registries_and_locks()
|
||||
self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)
|
||||
self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms)
|
||||
self._init_runtime_owners()
|
||||
self._init_stage_coordinator()
|
||||
self._init_arbiter()
|
||||
self._init_facades()
|
||||
self._start_arbiter_thread()
|
||||
|
||||
def _init_registries_and_locks(self) -> None:
|
||||
owner = self.owner
|
||||
owner.reference_registry = ReferenceRegistry()
|
||||
owner.model_registry = ModelRegistry(
|
||||
t2s_weights_path=str(owner.tts.configs.t2s_weights_path),
|
||||
vits_weights_path=str(owner.tts.configs.vits_weights_path),
|
||||
)
|
||||
owner.request_registry = EngineRequestRegistry(
|
||||
recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64")))
|
||||
)
|
||||
owner.engine_job_registry = SchedulerJobRegistry(threading.Lock())
|
||||
owner.direct_tts_lock = threading.RLock()
|
||||
owner.management_lock = threading.RLock()
|
||||
owner.engine_dispatch_last_snapshot = {}
|
||||
|
||||
def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
|
||||
owner = self.owner
|
||||
owner.scheduler_worker = UnifiedSchedulerWorker(
|
||||
owner.tts,
|
||||
max_steps=max_steps,
|
||||
micro_batch_wait_ms=micro_batch_wait_ms,
|
||||
runtime_callbacks=RuntimeStateCallbacks(
|
||||
update=owner._update_request_state,
|
||||
complete=owner._complete_request_state,
|
||||
fail=owner._fail_request_state,
|
||||
decode_runtime_update=owner._update_engine_decode_runtime_state,
|
||||
),
|
||||
external_finalize_submit=owner._enqueue_worker_finished_for_finalize,
|
||||
)
|
||||
|
||||
def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None:
|
||||
owner = self.owner
|
||||
worker_capacity_limits = owner.scheduler_worker.get_capacity_limits()
|
||||
prepare_max_inflight = int(owner.scheduler_worker.get_prepare_max_inflight())
|
||||
owner.engine_policy_config = EnginePolicyConfig(
|
||||
enabled=owner._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True),
|
||||
poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))),
|
||||
decode_backlog_soft_max=max(
|
||||
0,
|
||||
owner._env_int(
|
||||
"GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX",
|
||||
int(worker_capacity_limits["decode_backlog_max"]),
|
||||
),
|
||||
),
|
||||
finalize_pending_soft_max=max(
|
||||
0,
|
||||
owner._env_int(
|
||||
"GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX",
|
||||
int(worker_capacity_limits["finalize_pending_max"]),
|
||||
),
|
||||
),
|
||||
prepare_inflight_soft_max=max(
|
||||
0,
|
||||
owner._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight),
|
||||
),
|
||||
active_decode_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)),
|
||||
ready_for_prefill_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)),
|
||||
active_request_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)),
|
||||
)
|
||||
owner.engine_arbiter_config = EngineArbiterConfig(
|
||||
poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))),
|
||||
decode_burst=max(1, owner._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)),
|
||||
prepare_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)),
|
||||
finalize_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)),
|
||||
)
|
||||
|
||||
def _init_runtime_owners(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_decode_runtime_owner = EngineDecodeRuntimeOwner(
|
||||
get_decode_runtime_counters=owner.scheduler_worker.get_decode_runtime_counters,
|
||||
get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s,
|
||||
)
|
||||
owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")
|
||||
|
||||
def _init_stage_coordinator(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_stage_coordinator = EngineStageCoordinator(
|
||||
tts=owner.tts,
|
||||
scheduler_worker=owner.scheduler_worker,
|
||||
prepare_queue_owner=owner.engine_prepare_queue_owner,
|
||||
finalize_queue_owner=owner.engine_finalize_queue_owner,
|
||||
dispatch_queue_owner=owner.engine_dispatch_queue_owner,
|
||||
decode_runtime_owner=owner.engine_decode_runtime_owner,
|
||||
update_request_state=owner._update_request_state,
|
||||
merge_request_state_profile=owner._merge_request_state_profile,
|
||||
fail_request_state=owner._fail_request_state,
|
||||
get_engine_job=owner._get_engine_job,
|
||||
register_engine_job=owner._register_engine_job,
|
||||
fail_engine_jobs=owner._fail_engine_jobs,
|
||||
complete_engine_job=owner._complete_engine_job,
|
||||
add_engine_prefill_time=owner._add_engine_prefill_time,
|
||||
add_engine_merge_time=owner._add_engine_merge_time,
|
||||
add_engine_decode_time=owner._add_engine_decode_time,
|
||||
enqueue_engine_finished_items=owner._enqueue_engine_finished_items,
|
||||
snapshot_engine_dispatch_state=owner._snapshot_engine_dispatch_state,
|
||||
snapshot_engine_decode_runtime_state=owner._snapshot_engine_decode_runtime_state,
|
||||
)
|
||||
|
||||
def _init_arbiter(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_policy_arbiter = EnginePolicyArbiterController(
|
||||
policy_config=owner.engine_policy_config,
|
||||
arbiter_config=owner.engine_arbiter_config,
|
||||
snapshot_request_registry=owner._snapshot_request_registry,
|
||||
get_worker_state=owner.get_scheduler_state,
|
||||
snapshot_prepare_state=owner._snapshot_engine_prepare_state,
|
||||
snapshot_finalize_state=owner._snapshot_engine_finalize_state,
|
||||
snapshot_dispatch_state=owner._snapshot_engine_dispatch_state,
|
||||
snapshot_decode_runtime_state=owner._snapshot_engine_decode_runtime_state,
|
||||
snapshot_job_registry=owner._snapshot_engine_job_registry,
|
||||
peek_queue_age_ms=owner.engine_stage_coordinator.peek_queue_age_ms,
|
||||
merge_request_state_profile=owner._merge_request_state_profile,
|
||||
)
|
||||
owner.engine_stage_coordinator.bind_arbiter(
|
||||
notify_arbiter=owner._notify_engine_arbiter,
|
||||
select_stage=owner._select_engine_stage,
|
||||
mark_arbiter_tick=lambda stage, reason, policy_allowed: owner._mark_arbiter_tick(
|
||||
stage=stage,
|
||||
reason=reason,
|
||||
policy_allowed=policy_allowed,
|
||||
),
|
||||
wait_arbiter=owner.engine_policy_arbiter.wait,
|
||||
)
|
||||
|
||||
def _init_facades(self) -> None:
|
||||
owner = self.owner
|
||||
owner.bridge_facade = EngineBridgeFacade(owner)
|
||||
owner.api_facade = EngineApiFacade(owner)
|
||||
owner.runtime_facade = EngineRuntimeFacade(owner)
|
||||
|
||||
def _start_arbiter_thread(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_arbiter_thread = threading.Thread(
|
||||
target=owner._run_engine_arbiter_loop,
|
||||
name="unified-engine-arbiter",
|
||||
daemon=True,
|
||||
)
|
||||
owner.engine_arbiter_thread.start()
|
||||
1150
GPT_SoVITS/TTS_infer_pack/unified_engine_components.py
Normal file
1150
GPT_SoVITS/TTS_infer_pack/unified_engine_components.py
Normal file
File diff suppressed because it is too large
Load Diff
446
GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py
Normal file
446
GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py
Normal file
@ -0,0 +1,446 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
|
||||
|
||||
|
||||
class EngineBridgeDelegates:
|
||||
def _register_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
api_mode: str,
|
||||
backend: str,
|
||||
media_type: str,
|
||||
response_streaming: bool,
|
||||
deadline_ts: float | None = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
) -> EngineRequestState:
|
||||
return self.bridge_facade._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode=api_mode,
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=response_streaming,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.bridge_facade._update_request_state(request_id, status, extra)
|
||||
|
||||
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.bridge_facade._merge_request_state_profile(request_id, extra)
|
||||
|
||||
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_prepare_state()
|
||||
|
||||
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_finalize_state()
|
||||
|
||||
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_dispatch_state()
|
||||
|
||||
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.bridge_facade._register_engine_job(job)
|
||||
|
||||
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.bridge_facade._get_engine_job(request_id)
|
||||
|
||||
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.bridge_facade._pop_engine_job(request_id)
|
||||
|
||||
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_job_registry()
|
||||
|
||||
def _is_engine_drained(self) -> bool:
|
||||
return self.bridge_facade._is_engine_drained()
|
||||
|
||||
def _record_engine_job_done(self, request_id: str) -> None:
|
||||
self.bridge_facade._record_engine_job_done(request_id)
|
||||
|
||||
def _complete_engine_job(
|
||||
self,
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
*,
|
||||
sample_rate: int,
|
||||
audio_data: np.ndarray,
|
||||
) -> None:
|
||||
self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
|
||||
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
|
||||
self.bridge_facade._fail_engine_jobs(request_ids, error)
|
||||
|
||||
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
|
||||
self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s)
|
||||
|
||||
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s)
|
||||
|
||||
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s)
|
||||
|
||||
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
|
||||
self.bridge_facade._enqueue_engine_finished_items(items)
|
||||
|
||||
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_decode_pending_queue_state()
|
||||
|
||||
@staticmethod
|
||||
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
|
||||
return EngineBridgeFacade._summarize_active_batch(active_batch)
|
||||
|
||||
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
|
||||
self.bridge_facade._refresh_engine_decode_runtime_state(last_event)
|
||||
|
||||
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
|
||||
self.bridge_facade._update_engine_decode_runtime_state(snapshot)
|
||||
|
||||
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_decode_runtime_state()
|
||||
|
||||
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_arbiter_state()
|
||||
|
||||
def _notify_engine_arbiter(self) -> None:
|
||||
self.bridge_facade._notify_engine_arbiter()
|
||||
|
||||
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.bridge_facade._enqueue_engine_decode_pending_job(job)
|
||||
|
||||
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)
|
||||
|
||||
def _peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
return self.bridge_facade._peek_queue_age_ms(queue_name)
|
||||
|
||||
def _engine_has_pending_work(self) -> bool:
|
||||
return self.bridge_facade._engine_has_pending_work()
|
||||
|
||||
async def _prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.bridge_facade._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=prepare_submit_at,
|
||||
engine_request_id=engine_request_id,
|
||||
)
|
||||
|
||||
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
self.bridge_facade._enqueue_worker_finished_for_finalize(tasks)
|
||||
|
||||
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
return self.bridge_facade._take_engine_finalize_batch_nonblocking()
|
||||
|
||||
async def _enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
return await self.bridge_facade._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
|
||||
self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
|
||||
|
||||
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
return self.bridge_facade._select_engine_stage()
|
||||
|
||||
def _run_engine_prepare_once(self) -> bool:
|
||||
return self.bridge_facade._run_engine_prepare_once()
|
||||
|
||||
def _run_engine_finalize_once(self) -> bool:
|
||||
return self.bridge_facade._run_engine_finalize_once()
|
||||
|
||||
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
|
||||
def _run_engine_decode_runtime_once(self) -> bool:
|
||||
return self.bridge_facade._run_engine_decode_runtime_once()
|
||||
|
||||
def _run_engine_arbiter_loop(self) -> None:
|
||||
self.bridge_facade._run_engine_arbiter_loop()
|
||||
|
||||
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.bridge_facade._complete_request_state(request_id, extra)
|
||||
|
||||
def _fail_request_state(self, request_id: str, error: str) -> None:
|
||||
self.bridge_facade._fail_request_state(request_id, error)
|
||||
|
||||
def _snapshot_request_registry(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_request_registry()
|
||||
|
||||
|
||||
class EngineApiDelegates:
|
||||
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
|
||||
return self.api_facade._collect_request_summaries(request_ids)
|
||||
|
||||
def _has_active_request(self, request_id: str) -> bool:
|
||||
return self.api_facade._has_active_request(request_id)
|
||||
|
||||
@staticmethod
|
||||
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return EngineApiFacade._build_request_meta(payload)
|
||||
|
||||
@staticmethod
|
||||
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
|
||||
return EngineApiFacade._sum_profile_field(items, key)
|
||||
|
||||
def _build_direct_segment_trace(
|
||||
self,
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
|
||||
|
||||
def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_direct_scheduler_profile(**kwargs)
|
||||
|
||||
def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_legacy_direct_profile(**kwargs)
|
||||
|
||||
def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_scheduler_submit_profile(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _format_ms_header(value: Any) -> str:
|
||||
return EngineApiFacade._format_ms_header(value)
|
||||
|
||||
def _build_scheduler_submit_headers(
|
||||
self,
|
||||
*,
|
||||
request_id: str,
|
||||
media_type: str,
|
||||
sample_rate: int,
|
||||
profile: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
return self.api_facade._build_scheduler_submit_headers(
|
||||
request_id=request_id,
|
||||
media_type=media_type,
|
||||
sample_rate=sample_rate,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_scheduler_debug_request_profile(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
|
||||
return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs)
|
||||
|
||||
def _normalize_lang(self, value: str | None) -> str | None:
|
||||
return self.api_facade._normalize_lang(value)
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
|
||||
return EngineApiFacade._aggregate_numeric_dicts(items)
|
||||
|
||||
def _apply_default_reference(self, req: dict) -> dict:
|
||||
return self.api_facade._apply_default_reference(req)
|
||||
|
||||
def check_params(self, req: dict) -> Optional[str]:
|
||||
return self.api_facade.check_params(req)
|
||||
|
||||
@staticmethod
|
||||
def _base_request_defaults() -> Dict[str, Any]:
|
||||
return EngineApiFacade._base_request_defaults()
|
||||
|
||||
def _normalize_engine_request(
|
||||
self,
|
||||
payload: dict | NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str | None = None,
|
||||
normalize_streaming: bool = False,
|
||||
error_prefix: str = "request 参数非法: ",
|
||||
) -> NormalizedEngineRequest:
|
||||
return self.api_facade._normalize_engine_request(
|
||||
payload,
|
||||
request_id=request_id,
|
||||
normalize_streaming=normalize_streaming,
|
||||
error_prefix=error_prefix,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_streaming_mode(req: dict) -> dict:
|
||||
return EngineApiFacade._normalize_streaming_mode(req)
|
||||
|
||||
@staticmethod
|
||||
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
|
||||
return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths)
|
||||
|
||||
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
|
||||
return self.api_facade._select_direct_backend(normalized)
|
||||
|
||||
def _iter_legacy_direct_tts_bytes(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
return self.api_facade._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
|
||||
return self.api_facade._should_use_scheduler_backend_for_direct(req)
|
||||
|
||||
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
|
||||
return self.api_facade._segment_direct_text(normalized)
|
||||
|
||||
def _build_segment_request(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str,
|
||||
text: str,
|
||||
) -> NormalizedEngineRequest:
|
||||
return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text)
|
||||
|
||||
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
|
||||
return await self.api_facade._run_direct_tts_via_scheduler(normalized)
|
||||
|
||||
def _run_legacy_direct_tts_blocking(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
return self.api_facade._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def _run_direct_tts_via_legacy_backend(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
return await self.api_facade._run_direct_tts_via_legacy_backend(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
return await self.api_facade.run_direct_tts_async(req)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
return self.api_facade.run_direct_tts(req)
|
||||
|
||||
def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
return self.api_facade.build_scheduler_request_specs(request_items)
|
||||
|
||||
def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
return self.api_facade.build_scheduler_submit_spec(payload)
|
||||
|
||||
@staticmethod
|
||||
def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
return EngineApiFacade.summarize_scheduler_states(states)
|
||||
|
||||
@staticmethod
|
||||
def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
return EngineApiFacade.summarize_scheduler_finished(items)
|
||||
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
return await self.api_facade.run_scheduler_submit(payload)
|
||||
|
||||
|
||||
class EngineRuntimeDelegates:
|
||||
@staticmethod
|
||||
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
|
||||
return EngineRuntimeFacade._safe_component_snapshot(component)
|
||||
|
||||
def _build_stage_counters(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_stage_counters(request_registry, worker_state)
|
||||
|
||||
def _build_engine_policy_snapshot(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state)
|
||||
|
||||
async def _wait_for_engine_policy_admission(
|
||||
self,
|
||||
*,
|
||||
request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
return await self.engine_policy_arbiter.wait_for_policy_admission(
|
||||
request_id=request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _build_stage_summary(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
|
||||
|
||||
def get_scheduler_state(self) -> dict:
|
||||
return self.runtime_facade.get_scheduler_state()
|
||||
|
||||
def get_runtime_state(self) -> dict:
|
||||
return self.runtime_facade.get_runtime_state()
|
||||
|
||||
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
|
||||
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)
|
||||
|
||||
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
|
||||
return self.runtime_facade.set_refer_audio(refer_audio_path)
|
||||
|
||||
def set_gpt_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_gpt_weights(weights_path)
|
||||
|
||||
def set_sovits_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_sovits_weights(weights_path)
|
||||
|
||||
def handle_control(self, command: str) -> None:
|
||||
self.runtime_facade.handle_control(command)
|
||||
198
GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py
Normal file
198
GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py
Normal file
@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class EngineRuntimeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
@property
|
||||
def tts(self):
|
||||
return self.owner.tts
|
||||
|
||||
@property
|
||||
def reference_registry(self):
|
||||
return self.owner.reference_registry
|
||||
|
||||
@property
|
||||
def model_registry(self):
|
||||
return self.owner.model_registry
|
||||
|
||||
@property
|
||||
def scheduler_worker(self):
|
||||
return self.owner.scheduler_worker
|
||||
|
||||
@property
|
||||
def engine_decode_runtime_owner(self):
|
||||
return self.owner.engine_decode_runtime_owner
|
||||
|
||||
@property
|
||||
def engine_policy_arbiter(self):
|
||||
return self.owner.engine_policy_arbiter
|
||||
|
||||
@property
|
||||
def management_lock(self):
|
||||
return self.owner.management_lock
|
||||
|
||||
@property
|
||||
def direct_tts_lock(self):
|
||||
return self.owner.direct_tts_lock
|
||||
|
||||
@property
|
||||
def control_callbacks(self):
|
||||
return self.owner.control_callbacks
|
||||
|
||||
@staticmethod
|
||||
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
|
||||
if component is None or not hasattr(component, "snapshot"):
|
||||
return None
|
||||
try:
|
||||
return dict(component.snapshot())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _build_stage_counters(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state)
|
||||
|
||||
def _build_engine_policy_snapshot(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state)
|
||||
|
||||
def _build_stage_summary(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
counters = self._build_stage_counters(request_registry, worker_state)
|
||||
bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None))
|
||||
ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None))
|
||||
text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None))
|
||||
|
||||
return {
|
||||
**counters,
|
||||
"engine_drained": bool(self.owner._is_engine_drained()),
|
||||
"admission_config": {
|
||||
"decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)),
|
||||
"finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)),
|
||||
},
|
||||
"engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state),
|
||||
"engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(),
|
||||
"engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(),
|
||||
"engine_job_registry": self.owner._snapshot_engine_job_registry(),
|
||||
"engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
|
||||
"engine_prepare_state": self.owner._snapshot_engine_prepare_state(),
|
||||
"engine_finalize_state": self.owner._snapshot_engine_finalize_state(),
|
||||
"engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(),
|
||||
"active_batch": dict(worker_state.get("active_batch") or {}),
|
||||
"prepare_state": dict(worker_state.get("prepare_state") or {}),
|
||||
"bert_batch_worker_state": bert_worker_state,
|
||||
"ref_semantic_worker_state": ref_semantic_worker_state,
|
||||
"text_preprocessor_state": text_preprocessor_state,
|
||||
}
|
||||
|
||||
def get_scheduler_state(self) -> dict:
|
||||
return self.scheduler_worker.snapshot()
|
||||
|
||||
def get_runtime_state(self) -> dict:
|
||||
model_state = self.model_registry.snapshot()
|
||||
default_ref = self.reference_registry.get_default()
|
||||
scheduler_state = self.get_scheduler_state()
|
||||
request_registry = self.owner._snapshot_request_registry()
|
||||
engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state)
|
||||
engine_arbiter_state = self.owner._snapshot_engine_arbiter_state()
|
||||
engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state()
|
||||
engine_job_registry = self.owner._snapshot_engine_job_registry()
|
||||
engine_prepare_state = self.owner._snapshot_engine_prepare_state()
|
||||
engine_finalize_state = self.owner._snapshot_engine_finalize_state()
|
||||
engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state()
|
||||
engine_drained = self.owner._is_engine_drained()
|
||||
return {
|
||||
"message": "success",
|
||||
"default_reference": {
|
||||
"ref_audio_path": default_ref.ref_audio_path,
|
||||
"updated_at": default_ref.updated_at,
|
||||
},
|
||||
"model_registry": {
|
||||
"generation": model_state.generation,
|
||||
"t2s_generation": model_state.t2s_generation,
|
||||
"vits_generation": model_state.vits_generation,
|
||||
"t2s_weights_path": model_state.t2s_weights_path,
|
||||
"vits_weights_path": model_state.vits_weights_path,
|
||||
"updated_at": model_state.updated_at,
|
||||
},
|
||||
"worker_state": scheduler_state,
|
||||
"engine_policy": engine_policy,
|
||||
"engine_arbiter_state": engine_arbiter_state,
|
||||
"engine_decode_runtime_state": engine_decode_runtime_state,
|
||||
"engine_job_registry": engine_job_registry,
|
||||
"engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
|
||||
"engine_prepare_state": engine_prepare_state,
|
||||
"engine_finalize_state": engine_finalize_state,
|
||||
"engine_dispatcher_state": engine_dispatcher_state,
|
||||
"engine_drained": bool(engine_drained),
|
||||
"request_registry": request_registry,
|
||||
"stage_summary": self._build_stage_summary(request_registry, scheduler_state),
|
||||
}
|
||||
|
||||
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
|
||||
if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec):
|
||||
raise TimeoutError("scheduler worker did not drain before model reload")
|
||||
|
||||
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
|
||||
if refer_audio_path in [None, ""]:
|
||||
state = self.reference_registry.clear()
|
||||
return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
|
||||
if not os.path.exists(str(refer_audio_path)):
|
||||
raise FileNotFoundError(f"{refer_audio_path} not exists")
|
||||
with self.management_lock:
|
||||
with self.direct_tts_lock:
|
||||
self.tts.set_ref_audio(str(refer_audio_path))
|
||||
state = self.reference_registry.set_default(str(refer_audio_path))
|
||||
return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
|
||||
|
||||
def set_gpt_weights(self, weights_path: str) -> dict:
|
||||
if weights_path in ["", None]:
|
||||
raise ValueError("gpt weight path is required")
|
||||
with self.management_lock:
|
||||
self._wait_for_safe_reload()
|
||||
with self.direct_tts_lock:
|
||||
self.tts.init_t2s_weights(weights_path)
|
||||
self.tts.refresh_runtime_components()
|
||||
state = self.model_registry.mark_t2s_reload(str(weights_path))
|
||||
return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation}
|
||||
|
||||
def set_sovits_weights(self, weights_path: str) -> dict:
|
||||
if weights_path in ["", None]:
|
||||
raise ValueError("sovits weight path is required")
|
||||
with self.management_lock:
|
||||
self._wait_for_safe_reload()
|
||||
with self.direct_tts_lock:
|
||||
self.tts.init_vits_weights(weights_path)
|
||||
self.tts.refresh_runtime_components()
|
||||
state = self.model_registry.mark_vits_reload(str(weights_path))
|
||||
return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation}
|
||||
|
||||
def handle_control(self, command: str) -> None:
|
||||
if command == "restart":
|
||||
if self.control_callbacks.restart is None:
|
||||
os.execl(sys.executable, sys.executable, *sys.argv)
|
||||
self.control_callbacks.restart()
|
||||
return
|
||||
if command == "exit":
|
||||
if self.control_callbacks.exit is None:
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
return
|
||||
self.control_callbacks.exit()
|
||||
return
|
||||
raise ValueError(f"unsupported command: {command}")
|
||||
420
GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py
Normal file
420
GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py
Normal file
@ -0,0 +1,420 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
EngineDecodeRuntimeOwner,
|
||||
EngineDispatchTask,
|
||||
EngineGpuPrepareTask,
|
||||
EngineStatus,
|
||||
EngineTaskQueueOwner,
|
||||
SchedulerFinalizeTask,
|
||||
SchedulerPendingJob,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineStageCoordinator:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tts: TTS,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
|
||||
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
|
||||
fail_request_state: Callable[[str, str], None],
|
||||
get_engine_job: Callable[[str], SchedulerPendingJob | None],
|
||||
register_engine_job: Callable[[SchedulerPendingJob], None],
|
||||
fail_engine_jobs: Callable[[List[str], str], None],
|
||||
complete_engine_job: Callable[..., None],
|
||||
add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
|
||||
add_engine_merge_time: Callable[[List[str], float], None],
|
||||
add_engine_decode_time: Callable[[List[str], float], None],
|
||||
enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
|
||||
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
|
||||
) -> None:
|
||||
self.tts = tts
|
||||
self.scheduler_worker = scheduler_worker
|
||||
self.prepare_queue_owner = prepare_queue_owner
|
||||
self.finalize_queue_owner = finalize_queue_owner
|
||||
self.dispatch_queue_owner = dispatch_queue_owner
|
||||
self.decode_runtime_owner = decode_runtime_owner
|
||||
self.update_request_state = update_request_state
|
||||
self.merge_request_state_profile = merge_request_state_profile
|
||||
self.fail_request_state = fail_request_state
|
||||
self.get_engine_job = get_engine_job
|
||||
self.register_engine_job = register_engine_job
|
||||
self.fail_engine_jobs = fail_engine_jobs
|
||||
self.complete_engine_job = complete_engine_job
|
||||
self.add_engine_prefill_time = add_engine_prefill_time
|
||||
self.add_engine_merge_time = add_engine_merge_time
|
||||
self.add_engine_decode_time = add_engine_decode_time
|
||||
self.enqueue_engine_finished_items = enqueue_engine_finished_items
|
||||
self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
|
||||
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
|
||||
self._notify_arbiter: Callable[[], None] | None = None
|
||||
self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None
|
||||
self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None
|
||||
self._wait_arbiter: Callable[[], None] | None = None
|
||||
|
||||
def bind_arbiter(
|
||||
self,
|
||||
*,
|
||||
notify_arbiter: Callable[[], None],
|
||||
select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
|
||||
mark_arbiter_tick: Callable[[str, str, bool], None],
|
||||
wait_arbiter: Callable[[], None],
|
||||
) -> None:
|
||||
self._notify_arbiter = notify_arbiter
|
||||
self._select_stage = select_stage
|
||||
self._mark_arbiter_tick = mark_arbiter_tick
|
||||
self._wait_arbiter = wait_arbiter
|
||||
|
||||
def notify_arbiter(self) -> None:
|
||||
if self._notify_arbiter is not None:
|
||||
self._notify_arbiter()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
|
||||
if future.done():
|
||||
return
|
||||
future.set_exception(error)
|
||||
|
||||
def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _resolve_prepare_future(
|
||||
future: asyncio.Future,
|
||||
payload: tuple[T2SRequestState, float, float],
|
||||
) -> None:
|
||||
if future.done():
|
||||
return
|
||||
future.set_result(payload)
|
||||
|
||||
def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _notify_prepare_result(
|
||||
self,
|
||||
task: EngineGpuPrepareTask,
|
||||
payload: tuple[T2SRequestState, float, float],
|
||||
) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
|
||||
if engine_request_id not in [None, ""]:
|
||||
self.update_request_state(
|
||||
str(engine_request_id),
|
||||
EngineStatus.GPU_PREPARING,
|
||||
{
|
||||
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
|
||||
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
|
||||
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
|
||||
"text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms),
|
||||
},
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
task = EngineGpuPrepareTask(
|
||||
request_id=spec.request_id,
|
||||
cpu_stage=cpu_stage,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or spec.request_id,
|
||||
enqueue_time=time.perf_counter(),
|
||||
)
|
||||
self.prepare_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = await done_future
|
||||
return state, prepare_exec_started_at, prepare_exec_finished_at
|
||||
|
||||
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
self.update_request_state(
|
||||
job.engine_request_id,
|
||||
EngineStatus.READY_FOR_FINALIZE,
|
||||
{
|
||||
"finish_reason": task.item.finish_reason,
|
||||
"semantic_len": int(task.item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(task.item.finish_idx),
|
||||
},
|
||||
)
|
||||
self.finalize_queue_owner.enqueue_many(tasks)
|
||||
self.notify_arbiter()
|
||||
|
||||
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
|
||||
return self.finalize_queue_owner.take_finalize_batch(
|
||||
finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
|
||||
batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
|
||||
batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
|
||||
use_vocoder=bool(self.tts.configs.use_vocoder),
|
||||
)
|
||||
|
||||
async def enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
task = EngineDispatchTask(
|
||||
request_id=state.request_id,
|
||||
state=state,
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
prepare_wall_ms=float(prepare_wall_ms),
|
||||
prepare_profile_total_ms=float(prepare_profile_total_ms),
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or state.request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
enqueue_time=time.perf_counter(),
|
||||
)
|
||||
self.dispatch_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
self.merge_request_state_profile(
|
||||
task.engine_request_id or task.request_id,
|
||||
{
|
||||
"engine_dispatch_queue_depth_on_enqueue": int(
|
||||
self.snapshot_engine_dispatch_state()["waiting_count"]
|
||||
),
|
||||
},
|
||||
)
|
||||
return task
|
||||
|
||||
def peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
if queue_name == "prepare":
|
||||
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "finalize":
|
||||
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
|
||||
if queue_name == "decode_runtime_pending":
|
||||
return self.decode_runtime_owner.pending_age_ms()
|
||||
return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
|
||||
def has_pending_work(self) -> bool:
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
if self.decode_runtime_owner.has_pending_jobs():
|
||||
return True
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get(
|
||||
"active_request_count", 0
|
||||
) > 0:
|
||||
return True
|
||||
if self.prepare_queue_owner.has_items():
|
||||
return True
|
||||
if self.finalize_queue_owner.has_items():
|
||||
return True
|
||||
return self.dispatch_queue_owner.has_items()
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
task = self.prepare_queue_owner.pop_left()
|
||||
if task is None:
|
||||
return False
|
||||
queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
|
||||
try:
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
|
||||
self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
|
||||
)
|
||||
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
|
||||
)
|
||||
self.prepare_queue_owner.mark_completed(1)
|
||||
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
|
||||
return True
|
||||
except Exception as exc:
|
||||
task.error = str(exc)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
|
||||
self._notify_prepare_error(task, exc)
|
||||
return True
|
||||
|
||||
def run_engine_finalize_once(self) -> bool:
|
||||
tasks = self.take_engine_finalize_batch_nonblocking()
|
||||
if not tasks:
|
||||
return False
|
||||
self.scheduler_worker.begin_finalize_execution(len(tasks))
|
||||
try:
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
jobs_and_items.append((job, task.item))
|
||||
if not jobs_and_items:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
|
||||
for job, item in jobs_and_items:
|
||||
self.update_request_state(
|
||||
job.engine_request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
},
|
||||
)
|
||||
synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
|
||||
for job, _ in jobs_and_items:
|
||||
job.synth_ms += float(synth_ms)
|
||||
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
|
||||
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
except Exception as exc:
|
||||
self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
|
||||
finally:
|
||||
self.scheduler_worker.end_finalize_execution(len(tasks))
|
||||
self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
|
||||
return True
|
||||
|
||||
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
if not bool(policy_snapshot.get("allowed", True)):
|
||||
return False
|
||||
dispatch_task = self.dispatch_queue_owner.pop_left()
|
||||
if dispatch_task is None:
|
||||
return False
|
||||
dispatched_at = time.perf_counter()
|
||||
dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
|
||||
dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
|
||||
dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
|
||||
dispatch_task.engine_policy_snapshot = dict(policy_snapshot)
|
||||
try:
|
||||
worker_job = self.scheduler_worker.submit(
|
||||
state=dispatch_task.state,
|
||||
speed_factor=dispatch_task.speed_factor,
|
||||
sample_steps=dispatch_task.sample_steps,
|
||||
media_type=dispatch_task.media_type,
|
||||
prepare_wall_ms=dispatch_task.prepare_wall_ms,
|
||||
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
|
||||
done_loop=dispatch_task.done_loop,
|
||||
done_future=dispatch_task.done_future,
|
||||
engine_request_id=dispatch_task.engine_request_id,
|
||||
timeout_sec=dispatch_task.timeout_sec,
|
||||
skip_capacity_wait=True,
|
||||
admission_wait_ms_override=0.0,
|
||||
admission_snapshot_override=dict(worker_state),
|
||||
engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
|
||||
engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
|
||||
enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(),
|
||||
)
|
||||
dispatch_task.worker_job = worker_job
|
||||
self.register_engine_job(worker_job)
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
self.decode_runtime_owner.enqueue_pending_job(worker_job)
|
||||
self.notify_arbiter()
|
||||
self.dispatch_queue_owner.mark_completed(1)
|
||||
return True
|
||||
except Exception as exc:
|
||||
dispatch_task.error = str(exc)
|
||||
self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
|
||||
self._notify_dispatch_error(dispatch_task, exc)
|
||||
return True
|
||||
|
||||
def run_engine_decode_runtime_once(self) -> bool:
|
||||
if not self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
return False
|
||||
runtime_state = self.snapshot_engine_decode_runtime_state()
|
||||
pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
|
||||
wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0
|
||||
)
|
||||
result = self.scheduler_worker.execute_decode_cycle(
|
||||
pending_jobs=pending_jobs,
|
||||
active_batch=self.decode_runtime_owner.get_active_batch(),
|
||||
external_bookkeeping=True,
|
||||
)
|
||||
prefill_phase = dict(result.get("prefill_phase") or {})
|
||||
if prefill_phase.get("error"):
|
||||
self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error")))
|
||||
else:
|
||||
prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
|
||||
self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0)))
|
||||
self.add_engine_merge_time(
|
||||
[] if result.get("active_batch") is None else list(result["active_batch"].request_ids),
|
||||
float(prefill_phase.get("merge_elapsed_s", 0.0)),
|
||||
)
|
||||
self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))
|
||||
decode_phase = dict(result.get("decode_phase") or {})
|
||||
if decode_phase.get("error"):
|
||||
self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error")))
|
||||
else:
|
||||
self.add_engine_decode_time(
|
||||
list(decode_phase.get("request_ids") or []),
|
||||
float(decode_phase.get("decode_elapsed_s", 0.0)),
|
||||
)
|
||||
self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))
|
||||
self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
|
||||
if result.get("executed", False):
|
||||
self.decode_runtime_owner.refresh_state("engine_decode_cycle")
|
||||
return bool(result.get("executed", False))
|
||||
|
||||
def run_engine_arbiter_loop(self) -> None:
|
||||
if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None:
|
||||
raise RuntimeError("arbiter callbacks are not bound")
|
||||
while True:
|
||||
if not self.has_pending_work():
|
||||
self._mark_arbiter_tick("idle", "no_pending_work", True)
|
||||
self._wait_arbiter()
|
||||
continue
|
||||
stage, reason, policy_snapshot, worker_state = self._select_stage()
|
||||
policy_allowed = bool(policy_snapshot.get("allowed", True))
|
||||
executed = False
|
||||
if stage == "prepare":
|
||||
executed = self.run_engine_prepare_once()
|
||||
elif stage == "finalize":
|
||||
executed = self.run_engine_finalize_once()
|
||||
elif stage == "decode_dispatch":
|
||||
executed = self.run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
elif stage == "decode_runtime":
|
||||
executed = self.run_engine_decode_runtime_once()
|
||||
if not executed:
|
||||
self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed)
|
||||
self._wait_arbiter()
|
||||
continue
|
||||
self._mark_arbiter_tick(stage, reason, policy_allowed)
|
||||
1510
GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py
Normal file
1510
GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user