From 800f01790e243be3184681932d9747b90762b9d9 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 17:58:20 +0800 Subject: [PATCH] Refactor EngineApiFacade and EngineApiDelegates for improved method naming and structure Rename several methods in EngineApiFacade to follow a consistent private naming convention, enhancing code clarity. Update EngineApiDelegates to remove redundant method definitions, streamlining the interface. Introduce EnginePublicInterface to encapsulate public API methods, improving organization and maintainability of the TTS system. Additionally, update the EngineCompositionBuilder to use the new scheduler worker state retrieval method. --- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 3 +- .../TTS_infer_pack/unified_engine_api.py | 18 +++---- .../TTS_infer_pack/unified_engine_builder.py | 2 +- .../unified_engine_delegates.py | 45 ---------------- .../TTS_infer_pack/unified_engine_public.py | 53 +++++++++++++++++++ 5 files changed, 65 insertions(+), 56 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_public.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index 24e1c98b..a6faaddb 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -7,9 +7,10 @@ from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface -class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): +class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): @staticmethod def _env_flag(name: str, default: bool) -> bool: value = os.environ.get(name) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py index bcc8bf0d..ca76252d 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -826,7 +826,7 @@ class EngineApiFacade: request_id=f"{request_id}_seg_{segment_index:03d}", text=segment_text, ) - segment_specs.append(self.build_scheduler_submit_spec(segment_request)) + segment_specs.append(self._build_scheduler_submit_spec(segment_request)) prepared_items = await asyncio.gather( *[ @@ -1131,7 +1131,7 @@ class EngineApiFacade: fallback_reason="sync_direct_compat", ) - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: specs: List[SchedulerRequestSpec] = [] for index, payload in enumerate(request_items): normalized = self._normalize_engine_request( @@ -1142,7 +1142,7 @@ class EngineApiFacade: specs.append(normalized.to_scheduler_spec()) return specs - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: normalized = self._normalize_engine_request( payload, request_id=( @@ -1154,7 +1154,7 @@ class EngineApiFacade: return normalized.to_scheduler_spec() @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: return [ { "request_id": state.request_id, @@ -1169,7 +1169,7 @@ class EngineApiFacade: ] @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: return [ { "request_id": item.request_id, @@ -1183,7 +1183,7 @@ class EngineApiFacade: async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: request_start = time.perf_counter() set_scheduler_seed(seed) - specs = self.build_scheduler_request_specs(request_items) + specs = self._build_scheduler_request_specs(request_items) request_ids = [spec.request_id for spec in specs] for spec in specs: self._register_request_state( @@ -1270,8 +1270,8 @@ class EngineApiFacade: request_total_ms=request_total_ms, finished_items=finished, ), - "requests": self.summarize_scheduler_states(states), - "finished": self.summarize_scheduler_finished(finished), + "requests": self._summarize_scheduler_states(states), + "finished": self._summarize_scheduler_finished(finished), "request_profiles": request_profiles, "request_traces": self._collect_request_summaries(request_ids), } @@ -1284,7 +1284,7 @@ class EngineApiFacade: payload, request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), ) - spec = self.build_scheduler_submit_spec(normalized) + spec = self._build_scheduler_submit_spec(normalized) deadline_ts = None timeout_sec = normalized.timeout_sec if timeout_sec is not None: diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py index 45178b1f..0e93c442 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -143,7 +143,7 @@ class EngineCompositionBuilder: policy_config=owner.engine_policy_config, arbiter_config=owner.engine_arbiter_config, snapshot_request_registry=owner._snapshot_request_registry, - get_worker_state=owner.get_scheduler_state, + get_worker_state=owner.scheduler_worker.snapshot, snapshot_prepare_state=owner._snapshot_engine_prepare_state, snapshot_finalize_state=owner._snapshot_engine_finalize_state, snapshot_dispatch_state=owner._snapshot_engine_dispatch_state, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py index 7dbbd5bd..f68d3ede 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -360,33 +360,6 @@ class EngineApiDelegates: fallback_reason=fallback_reason, ) - async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: - return await self.api_facade.run_direct_tts_async(req) - - def run_direct_tts(self, req: dict) -> DirectTTSExecution: - return self.api_facade.run_direct_tts(req) - - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: - return self.api_facade.build_scheduler_request_specs(request_items) - - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: - return self.api_facade.build_scheduler_submit_spec(payload) - - @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return EngineApiFacade.summarize_scheduler_states(states) - - @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return EngineApiFacade.summarize_scheduler_finished(items) - - async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: - return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) - - async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: - return await self.api_facade.run_scheduler_submit(payload) - - class EngineRuntimeDelegates: @staticmethod def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: @@ -424,23 +397,5 @@ class EngineRuntimeDelegates: ) -> Dict[str, Any]: return self.runtime_facade._build_stage_summary(request_registry, worker_state) - def get_scheduler_state(self) -> dict: - return self.runtime_facade.get_scheduler_state() - - def get_runtime_state(self) -> dict: - return self.runtime_facade.get_runtime_state() - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) - - def set_refer_audio(self, refer_audio_path: str | None) -> dict: - return self.runtime_facade.set_refer_audio(refer_audio_path) - - def set_gpt_weights(self, weights_path: str) -> dict: - return self.runtime_facade.set_gpt_weights(weights_path) - - def set_sovits_weights(self, weights_path: str) -> dict: - return self.runtime_facade.set_sovits_weights(weights_path) - - def handle_control(self, command: str) -> None: - self.runtime_facade.handle_control(command) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py new file mode 100644 index 00000000..fbe88b85 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution + + +class EnginePublicInterface: + PUBLIC_API_METHODS = ( + "run_direct_tts_async", + "run_scheduler_submit", + "run_scheduler_debug", + "get_runtime_state", + "set_refer_audio", + "set_gpt_weights", + "set_sovits_weights", + "handle_control", + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.api_facade.run_direct_tts_async(req) + + async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.api_facade.run_scheduler_submit(payload) + + def get_runtime_state(self) -> dict: + return self.runtime_facade.get_runtime_state() + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + return self.runtime_facade.set_refer_audio(refer_audio_path) + + def set_gpt_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_gpt_weights(weights_path) + + def set_sovits_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_sovits_weights(weights_path) + + def handle_control(self, command: str) -> None: + self.runtime_facade.handle_control(command) + + +class EngineCompatInterface: + COMPAT_API_METHODS = ( + "run_direct_tts", + "get_scheduler_state", + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.api_facade.run_direct_tts(req) + + def get_scheduler_state(self) -> dict: + return self.runtime_facade.get_scheduler_state()