GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py
baicai-1145 800f01790e Refactor EngineApiFacade and EngineApiDelegates for improved method naming and structure
Rename several methods in EngineApiFacade to follow a consistent private naming convention, enhancing code clarity. Update EngineApiDelegates to remove redundant method definitions, streamlining the interface. Introduce EnginePublicInterface to encapsulate public API methods, improving organization and maintainability of the TTS system. Additionally, update the EngineCompositionBuilder to use the new scheduler worker state retrieval method.
2026-03-11 17:58:20 +08:00

180 lines
8.6 KiB
Python

from __future__ import annotations
import os
import threading
from typing import Any
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineArbiterConfig,
EngineDecodeRuntimeOwner,
EnginePolicyArbiterController,
EnginePolicyConfig,
EngineRequestRegistry,
EngineTaskQueueOwner,
ModelRegistry,
ReferenceRegistry,
RuntimeStateCallbacks,
SchedulerJobRegistry,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineCompositionBuilder:
    """Assemble the unified engine's collaborators onto a host ``owner`` object.

    Each ``_init_*`` step stores the objects it creates as attributes of
    ``owner``. Later steps read attributes written by earlier ones (for
    example the policy configuration queries ``owner.scheduler_worker``), so
    the call order used by :meth:`build` must be kept.
    """

    def __init__(self, owner: Any) -> None:
        """Remember the object that will receive every built component."""
        self.owner = owner

    def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
        """Run all composition steps in order and finally start the arbiter thread.

        Args:
            max_steps: Forwarded to ``UnifiedSchedulerWorker``.
            micro_batch_wait_ms: Forwarded to ``UnifiedSchedulerWorker`` and used
                as the default for the policy/arbiter poll-wait settings.
        """
        self._init_registries_and_locks()
        self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)
        self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms)
        # The remaining steps take no arguments; run them strictly in this order.
        remaining_steps = (
            self._init_runtime_owners,
            self._init_stage_coordinator,
            self._init_arbiter,
            self._init_facades,
            self._start_arbiter_thread,
        )
        for step in remaining_steps:
            step()

    def _soft_limit_from_env(self, env_name: str, fallback: int) -> int:
        """Return ``owner._env_int(env_name, fallback)`` clamped to be >= 0."""
        return max(0, self.owner._env_int(env_name, fallback))

    def _poll_wait_from_env(self, env_name: str, fallback_ms: int) -> float:
        """Return ``owner._env_float(env_name, float(fallback_ms))`` clamped to >= 1.0."""
        return max(1.0, self.owner._env_float(env_name, float(fallback_ms)))

    def _init_registries_and_locks(self) -> None:
        """Create the registries, locks and the empty dispatch snapshot on the owner."""
        host = self.owner
        host.reference_registry = ReferenceRegistry()

        tts_configs = host.tts.configs
        host.model_registry = ModelRegistry(
            t2s_weights_path=str(tts_configs.t2s_weights_path),
            vits_weights_path=str(tts_configs.vits_weights_path),
        )

        # The recent-request limit comes straight from the environment
        # (default "64") and is never allowed to drop below 1.
        raw_recent_limit = os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64")
        host.request_registry = EngineRequestRegistry(recent_limit=max(1, int(raw_recent_limit)))

        job_registry_lock = threading.Lock()
        host.engine_job_registry = SchedulerJobRegistry(job_registry_lock)
        host.direct_tts_lock = threading.RLock()
        host.management_lock = threading.RLock()
        host.engine_dispatch_last_snapshot = {}

    def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
        """Create the scheduler worker and hook it up to the owner's state callbacks."""
        host = self.owner
        state_callbacks = RuntimeStateCallbacks(
            update=host._update_request_state,
            complete=host._complete_request_state,
            fail=host._fail_request_state,
            decode_runtime_update=host._update_engine_decode_runtime_state,
        )
        host.scheduler_worker = UnifiedSchedulerWorker(
            host.tts,
            max_steps=max_steps,
            micro_batch_wait_ms=micro_batch_wait_ms,
            runtime_callbacks=state_callbacks,
            external_finalize_submit=host._enqueue_worker_finished_for_finalize,
        )

    def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None:
        """Build the policy and arbiter configs from env overrides and worker limits.

        Defaults for the decode-backlog, finalize-pending and prepare-inflight
        soft limits are taken from the scheduler worker; the other soft limits
        default to 0. Every value is clamped to its lower bound.
        """
        host = self.owner
        capacity = host.scheduler_worker.get_capacity_limits()
        prepare_cap = int(host.scheduler_worker.get_prepare_max_inflight())
        soft_limit = self._soft_limit_from_env
        poll_wait = self._poll_wait_from_env

        host.engine_policy_config = EnginePolicyConfig(
            enabled=host._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True),
            poll_wait_ms=poll_wait("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", micro_batch_wait_ms),
            decode_backlog_soft_max=soft_limit(
                "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX",
                int(capacity["decode_backlog_max"]),
            ),
            finalize_pending_soft_max=soft_limit(
                "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX",
                int(capacity["finalize_pending_max"]),
            ),
            prepare_inflight_soft_max=soft_limit(
                "GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX",
                prepare_cap,
            ),
            active_decode_soft_max=soft_limit("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0),
            ready_for_prefill_soft_max=soft_limit("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0),
            active_request_soft_max=soft_limit("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0),
        )

        host.engine_arbiter_config = EngineArbiterConfig(
            poll_wait_ms=poll_wait("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", micro_batch_wait_ms),
            decode_burst=max(1, host._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)),
            prepare_aging_ms=max(0.0, host._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)),
            finalize_aging_ms=max(0.0, host._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)),
        )

    def _init_runtime_owners(self) -> None:
        """Create the decode-runtime owner and the prepare/finalize/dispatch queue owners."""
        host = self.owner
        worker = host.scheduler_worker
        host.engine_decode_runtime_owner = EngineDecodeRuntimeOwner(
            get_decode_runtime_counters=worker.get_decode_runtime_counters,
            get_micro_batch_wait_s=worker.get_micro_batch_wait_s,
        )

        # (owner attribute, completion key) for each task queue, in creation order.
        queue_specs = (
            ("engine_prepare_queue_owner", "total_completed"),
            ("engine_finalize_queue_owner", "total_completed"),
            ("engine_dispatch_queue_owner", "total_dispatched"),
        )
        for attr_name, completion_key in queue_specs:
            setattr(host, attr_name, EngineTaskQueueOwner(completion_key=completion_key))

    def _init_stage_coordinator(self) -> None:
        """Create the stage coordinator from the queue owners and owner callbacks."""
        host = self.owner

        runtime_parts = {
            "prepare_queue_owner": host.engine_prepare_queue_owner,
            "finalize_queue_owner": host.engine_finalize_queue_owner,
            "dispatch_queue_owner": host.engine_dispatch_queue_owner,
            "decode_runtime_owner": host.engine_decode_runtime_owner,
        }
        request_hooks = {
            "update_request_state": host._update_request_state,
            "merge_request_state_profile": host._merge_request_state_profile,
            "fail_request_state": host._fail_request_state,
        }
        job_hooks = {
            "get_engine_job": host._get_engine_job,
            "register_engine_job": host._register_engine_job,
            "fail_engine_jobs": host._fail_engine_jobs,
            "complete_engine_job": host._complete_engine_job,
        }
        accounting_hooks = {
            "add_engine_prefill_time": host._add_engine_prefill_time,
            "add_engine_merge_time": host._add_engine_merge_time,
            "add_engine_decode_time": host._add_engine_decode_time,
            "enqueue_engine_finished_items": host._enqueue_engine_finished_items,
        }
        snapshot_hooks = {
            "snapshot_engine_dispatch_state": host._snapshot_engine_dispatch_state,
            "snapshot_engine_decode_runtime_state": host._snapshot_engine_decode_runtime_state,
        }

        host.engine_stage_coordinator = EngineStageCoordinator(
            tts=host.tts,
            scheduler_worker=host.scheduler_worker,
            **runtime_parts,
            **request_hooks,
            **job_hooks,
            **accounting_hooks,
            **snapshot_hooks,
        )

    def _init_arbiter(self) -> None:
        """Create the policy arbiter controller and bind it to the stage coordinator."""
        host = self.owner
        coordinator = host.engine_stage_coordinator

        snapshot_sources = {
            "snapshot_request_registry": host._snapshot_request_registry,
            "get_worker_state": host.scheduler_worker.snapshot,
            "snapshot_prepare_state": host._snapshot_engine_prepare_state,
            "snapshot_finalize_state": host._snapshot_engine_finalize_state,
            "snapshot_dispatch_state": host._snapshot_engine_dispatch_state,
            "snapshot_decode_runtime_state": host._snapshot_engine_decode_runtime_state,
            "snapshot_job_registry": host._snapshot_engine_job_registry,
        }
        host.engine_policy_arbiter = EnginePolicyArbiterController(
            policy_config=host.engine_policy_config,
            arbiter_config=host.engine_arbiter_config,
            **snapshot_sources,
            peek_queue_age_ms=coordinator.peek_queue_age_ms,
            merge_request_state_profile=host._merge_request_state_profile,
        )

        def mark_arbiter_tick(stage, reason, policy_allowed):
            # Adapter: re-issue the three arguments as keywords to the owner's method.
            return host._mark_arbiter_tick(
                stage=stage,
                reason=reason,
                policy_allowed=policy_allowed,
            )

        host.engine_stage_coordinator.bind_arbiter(
            notify_arbiter=host._notify_engine_arbiter,
            select_stage=host._select_engine_stage,
            mark_arbiter_tick=mark_arbiter_tick,
            wait_arbiter=host.engine_policy_arbiter.wait,
        )

    def _init_facades(self) -> None:
        """Attach the bridge, API and runtime facades, each constructed with the owner."""
        host = self.owner
        facade_specs = (
            ("bridge_facade", EngineBridgeFacade),
            ("api_facade", EngineApiFacade),
            ("runtime_facade", EngineRuntimeFacade),
        )
        for attr_name, facade_cls in facade_specs:
            setattr(host, attr_name, facade_cls(host))

    def _start_arbiter_thread(self) -> None:
        """Start a daemon thread named "unified-engine-arbiter" running the owner's arbiter loop."""
        host = self.owner
        arbiter_thread = threading.Thread(
            target=host._run_engine_arbiter_loop,
            name="unified-engine-arbiter",
            daemon=True,
        )
        host.engine_arbiter_thread = arbiter_thread
        arbiter_thread.start()