Refactor EngineApiFacade and EngineApiDelegates for improved method naming and structure

Rename several methods in EngineApiFacade to follow a consistent private naming convention, enhancing code clarity. Update EngineApiDelegates to remove redundant method definitions, streamlining the interface. Introduce EnginePublicInterface to encapsulate public API methods, improving organization and maintainability of the TTS system. Additionally, update the EngineCompositionBuilder to use the new scheduler worker state retrieval method.
This commit is contained in:
baicai-1145 2026-03-11 17:58:20 +08:00
parent d1ec7d9e54
commit 800f01790e
5 changed files with 65 additions and 56 deletions

View File

@ -7,9 +7,10 @@ from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks
from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates
from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface
class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates):
class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates):
@staticmethod
def _env_flag(name: str, default: bool) -> bool:
value = os.environ.get(name)

View File

@ -826,7 +826,7 @@ class EngineApiFacade:
request_id=f"{request_id}_seg_{segment_index:03d}",
text=segment_text,
)
segment_specs.append(self.build_scheduler_submit_spec(segment_request))
segment_specs.append(self._build_scheduler_submit_spec(segment_request))
prepared_items = await asyncio.gather(
*[
@ -1131,7 +1131,7 @@ class EngineApiFacade:
fallback_reason="sync_direct_compat",
)
def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
specs: List[SchedulerRequestSpec] = []
for index, payload in enumerate(request_items):
normalized = self._normalize_engine_request(
@ -1142,7 +1142,7 @@ class EngineApiFacade:
specs.append(normalized.to_scheduler_spec())
return specs
def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
normalized = self._normalize_engine_request(
payload,
request_id=(
@ -1154,7 +1154,7 @@ class EngineApiFacade:
return normalized.to_scheduler_spec()
@staticmethod
def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
return [
{
"request_id": state.request_id,
@ -1169,7 +1169,7 @@ class EngineApiFacade:
]
@staticmethod
def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
return [
{
"request_id": item.request_id,
@ -1183,7 +1183,7 @@ class EngineApiFacade:
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
request_start = time.perf_counter()
set_scheduler_seed(seed)
specs = self.build_scheduler_request_specs(request_items)
specs = self._build_scheduler_request_specs(request_items)
request_ids = [spec.request_id for spec in specs]
for spec in specs:
self._register_request_state(
@ -1270,8 +1270,8 @@ class EngineApiFacade:
request_total_ms=request_total_ms,
finished_items=finished,
),
"requests": self.summarize_scheduler_states(states),
"finished": self.summarize_scheduler_finished(finished),
"requests": self._summarize_scheduler_states(states),
"finished": self._summarize_scheduler_finished(finished),
"request_profiles": request_profiles,
"request_traces": self._collect_request_summaries(request_ids),
}
@ -1284,7 +1284,7 @@ class EngineApiFacade:
payload,
request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
)
spec = self.build_scheduler_submit_spec(normalized)
spec = self._build_scheduler_submit_spec(normalized)
deadline_ts = None
timeout_sec = normalized.timeout_sec
if timeout_sec is not None:

View File

@ -143,7 +143,7 @@ class EngineCompositionBuilder:
policy_config=owner.engine_policy_config,
arbiter_config=owner.engine_arbiter_config,
snapshot_request_registry=owner._snapshot_request_registry,
get_worker_state=owner.get_scheduler_state,
get_worker_state=owner.scheduler_worker.snapshot,
snapshot_prepare_state=owner._snapshot_engine_prepare_state,
snapshot_finalize_state=owner._snapshot_engine_finalize_state,
snapshot_dispatch_state=owner._snapshot_engine_dispatch_state,

View File

@ -360,33 +360,6 @@ class EngineApiDelegates:
fallback_reason=fallback_reason,
)
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
return await self.api_facade.run_direct_tts_async(req)
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
return self.api_facade.run_direct_tts(req)
def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
return self.api_facade.build_scheduler_request_specs(request_items)
def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
return self.api_facade.build_scheduler_submit_spec(payload)
@staticmethod
def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
return EngineApiFacade.summarize_scheduler_states(states)
@staticmethod
def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
return EngineApiFacade.summarize_scheduler_finished(items)
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
return await self.api_facade.run_scheduler_submit(payload)
class EngineRuntimeDelegates:
@staticmethod
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
@ -424,23 +397,5 @@ class EngineRuntimeDelegates:
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
def get_scheduler_state(self) -> dict:
return self.runtime_facade.get_scheduler_state()
def get_runtime_state(self) -> dict:
return self.runtime_facade.get_runtime_state()
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
return self.runtime_facade.set_refer_audio(refer_audio_path)
def set_gpt_weights(self, weights_path: str) -> dict:
return self.runtime_facade.set_gpt_weights(weights_path)
def set_sovits_weights(self, weights_path: str) -> dict:
return self.runtime_facade.set_sovits_weights(weights_path)
def handle_control(self, command: str) -> None:
self.runtime_facade.handle_control(command)

View File

@ -0,0 +1,53 @@
from __future__ import annotations
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution
class EnginePublicInterface:
PUBLIC_API_METHODS = (
"run_direct_tts_async",
"run_scheduler_submit",
"run_scheduler_debug",
"get_runtime_state",
"set_refer_audio",
"set_gpt_weights",
"set_sovits_weights",
"handle_control",
)
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
return await self.api_facade.run_direct_tts_async(req)
async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
return await self.api_facade.run_scheduler_submit(payload)
def get_runtime_state(self) -> dict:
return self.runtime_facade.get_runtime_state()
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
return self.runtime_facade.set_refer_audio(refer_audio_path)
def set_gpt_weights(self, weights_path: str) -> dict:
return self.runtime_facade.set_gpt_weights(weights_path)
def set_sovits_weights(self, weights_path: str) -> dict:
return self.runtime_facade.set_sovits_weights(weights_path)
def handle_control(self, command: str) -> None:
self.runtime_facade.handle_control(command)
class EngineCompatInterface:
COMPAT_API_METHODS = (
"run_direct_tts",
"get_scheduler_state",
)
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
return self.api_facade.run_direct_tts(req)
def get_scheduler_state(self) -> dict:
return self.runtime_facade.get_scheduler_state()