diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index 24e1c98b..a6faaddb 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -7,9 +7,10 @@ from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface -class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): +class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): @staticmethod def _env_flag(name: str, default: bool) -> bool: value = os.environ.get(name) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py index bcc8bf0d..ca76252d 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -826,7 +826,7 @@ class EngineApiFacade: request_id=f"{request_id}_seg_{segment_index:03d}", text=segment_text, ) - segment_specs.append(self.build_scheduler_submit_spec(segment_request)) + segment_specs.append(self._build_scheduler_submit_spec(segment_request)) prepared_items = await asyncio.gather( *[ @@ -1131,7 +1131,7 @@ class EngineApiFacade: fallback_reason="sync_direct_compat", ) - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: specs: List[SchedulerRequestSpec] = [] for index, payload in enumerate(request_items): normalized = self._normalize_engine_request( @@ -1142,7 +1142,7 @@ class EngineApiFacade: specs.append(normalized.to_scheduler_spec()) return specs - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: normalized = self._normalize_engine_request( payload, request_id=( @@ -1154,7 +1154,7 @@ class EngineApiFacade: return normalized.to_scheduler_spec() @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: return [ { "request_id": state.request_id, @@ -1169,7 +1169,7 @@ class EngineApiFacade: ] @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: return [ { "request_id": item.request_id, @@ -1183,7 +1183,7 @@ class EngineApiFacade: async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: request_start = time.perf_counter() set_scheduler_seed(seed) - specs = self.build_scheduler_request_specs(request_items) + specs = self._build_scheduler_request_specs(request_items) request_ids = [spec.request_id for spec in specs] for spec in specs: self._register_request_state( @@ -1270,8 +1270,8 @@ class EngineApiFacade: request_total_ms=request_total_ms, finished_items=finished, ), - "requests": self.summarize_scheduler_states(states), - "finished": self.summarize_scheduler_finished(finished), + "requests": self._summarize_scheduler_states(states), + "finished": self._summarize_scheduler_finished(finished), "request_profiles": request_profiles, "request_traces": self._collect_request_summaries(request_ids), } @@ -1284,7 +1284,7 @@ class EngineApiFacade: payload, request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), ) - spec = self.build_scheduler_submit_spec(normalized) + spec = self._build_scheduler_submit_spec(normalized) deadline_ts = None timeout_sec = normalized.timeout_sec if timeout_sec is not None: diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py index 45178b1f..0e93c442 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -143,7 +143,7 @@ class EngineCompositionBuilder: policy_config=owner.engine_policy_config, arbiter_config=owner.engine_arbiter_config, snapshot_request_registry=owner._snapshot_request_registry, - get_worker_state=owner.get_scheduler_state, + get_worker_state=owner.scheduler_worker.snapshot, snapshot_prepare_state=owner._snapshot_engine_prepare_state, snapshot_finalize_state=owner._snapshot_engine_finalize_state, snapshot_dispatch_state=owner._snapshot_engine_dispatch_state, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py index 7dbbd5bd..f68d3ede 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -360,33 +360,6 @@ class EngineApiDelegates: fallback_reason=fallback_reason, ) - async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: - return await self.api_facade.run_direct_tts_async(req) - - def run_direct_tts(self, req: dict) -> DirectTTSExecution: - return self.api_facade.run_direct_tts(req) - - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: - return self.api_facade.build_scheduler_request_specs(request_items) - - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: - return self.api_facade.build_scheduler_submit_spec(payload) - - @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return EngineApiFacade.summarize_scheduler_states(states) - - @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return EngineApiFacade.summarize_scheduler_finished(items) - - async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: - return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) - - async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: - return await self.api_facade.run_scheduler_submit(payload) - - class EngineRuntimeDelegates: @staticmethod def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: @@ -424,23 +397,5 @@ class EngineRuntimeDelegates: ) -> Dict[str, Any]: return self.runtime_facade._build_stage_summary(request_registry, worker_state) - def get_scheduler_state(self) -> dict: - return self.runtime_facade.get_scheduler_state() - - def get_runtime_state(self) -> dict: - return self.runtime_facade.get_runtime_state() - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) - - def set_refer_audio(self, refer_audio_path: str | None) -> dict: - return self.runtime_facade.set_refer_audio(refer_audio_path) - - def set_gpt_weights(self, weights_path: str) -> dict: - return self.runtime_facade.set_gpt_weights(weights_path) - - def set_sovits_weights(self, weights_path: str) -> dict: - return self.runtime_facade.set_sovits_weights(weights_path) - - def handle_control(self, command: str) -> None: - self.runtime_facade.handle_control(command) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py new file mode 100644 index 00000000..fbe88b85 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution + + +class EnginePublicInterface: + PUBLIC_API_METHODS = ( + "run_direct_tts_async", + "run_scheduler_submit", + "run_scheduler_debug", + "get_runtime_state", + "set_refer_audio", + "set_gpt_weights", + "set_sovits_weights", + "handle_control", + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.api_facade.run_direct_tts_async(req) + + async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.api_facade.run_scheduler_submit(payload) + + def get_runtime_state(self) -> dict: + return self.runtime_facade.get_runtime_state() + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + return self.runtime_facade.set_refer_audio(refer_audio_path) + + def set_gpt_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_gpt_weights(weights_path) + + def set_sovits_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_sovits_weights(weights_path) + + def handle_control(self, command: str) -> None: + self.runtime_facade.handle_control(command) + + +class EngineCompatInterface: + COMPAT_API_METHODS = ( + "run_direct_tts", + "get_scheduler_state", + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.api_facade.run_direct_tts(req) + + def get_scheduler_state(self) -> dict: + return self.runtime_facade.get_scheduler_state()