mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-26 14:28:13 +08:00
Refactor EngineApiFacade and EngineApiDelegates for improved method naming and structure
Rename several methods in EngineApiFacade to follow a consistent private naming convention, enhancing code clarity. Update EngineApiDelegates to remove redundant method definitions, streamlining the interface. Introduce EnginePublicInterface to encapsulate public API methods, improving organization and maintainability of the TTS system. Additionally, update the EngineCompositionBuilder to use the new scheduler worker state retrieval method.
This commit is contained in:
parent
d1ec7d9e54
commit
800f01790e
@ -7,9 +7,10 @@ from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface
|
||||
|
||||
|
||||
class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates):
|
||||
class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates):
|
||||
@staticmethod
|
||||
def _env_flag(name: str, default: bool) -> bool:
|
||||
value = os.environ.get(name)
|
||||
|
||||
@ -826,7 +826,7 @@ class EngineApiFacade:
|
||||
request_id=f"{request_id}_seg_{segment_index:03d}",
|
||||
text=segment_text,
|
||||
)
|
||||
segment_specs.append(self.build_scheduler_submit_spec(segment_request))
|
||||
segment_specs.append(self._build_scheduler_submit_spec(segment_request))
|
||||
|
||||
prepared_items = await asyncio.gather(
|
||||
*[
|
||||
@ -1131,7 +1131,7 @@ class EngineApiFacade:
|
||||
fallback_reason="sync_direct_compat",
|
||||
)
|
||||
|
||||
def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, payload in enumerate(request_items):
|
||||
normalized = self._normalize_engine_request(
|
||||
@ -1142,7 +1142,7 @@ class EngineApiFacade:
|
||||
specs.append(normalized.to_scheduler_spec())
|
||||
return specs
|
||||
|
||||
def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
normalized = self._normalize_engine_request(
|
||||
payload,
|
||||
request_id=(
|
||||
@ -1154,7 +1154,7 @@ class EngineApiFacade:
|
||||
return normalized.to_scheduler_spec()
|
||||
|
||||
@staticmethod
|
||||
def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
@ -1169,7 +1169,7 @@ class EngineApiFacade:
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"request_id": item.request_id,
|
||||
@ -1183,7 +1183,7 @@ class EngineApiFacade:
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
request_start = time.perf_counter()
|
||||
set_scheduler_seed(seed)
|
||||
specs = self.build_scheduler_request_specs(request_items)
|
||||
specs = self._build_scheduler_request_specs(request_items)
|
||||
request_ids = [spec.request_id for spec in specs]
|
||||
for spec in specs:
|
||||
self._register_request_state(
|
||||
@ -1270,8 +1270,8 @@ class EngineApiFacade:
|
||||
request_total_ms=request_total_ms,
|
||||
finished_items=finished,
|
||||
),
|
||||
"requests": self.summarize_scheduler_states(states),
|
||||
"finished": self.summarize_scheduler_finished(finished),
|
||||
"requests": self._summarize_scheduler_states(states),
|
||||
"finished": self._summarize_scheduler_finished(finished),
|
||||
"request_profiles": request_profiles,
|
||||
"request_traces": self._collect_request_summaries(request_ids),
|
||||
}
|
||||
@ -1284,7 +1284,7 @@ class EngineApiFacade:
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
|
||||
)
|
||||
spec = self.build_scheduler_submit_spec(normalized)
|
||||
spec = self._build_scheduler_submit_spec(normalized)
|
||||
deadline_ts = None
|
||||
timeout_sec = normalized.timeout_sec
|
||||
if timeout_sec is not None:
|
||||
|
||||
@ -143,7 +143,7 @@ class EngineCompositionBuilder:
|
||||
policy_config=owner.engine_policy_config,
|
||||
arbiter_config=owner.engine_arbiter_config,
|
||||
snapshot_request_registry=owner._snapshot_request_registry,
|
||||
get_worker_state=owner.get_scheduler_state,
|
||||
get_worker_state=owner.scheduler_worker.snapshot,
|
||||
snapshot_prepare_state=owner._snapshot_engine_prepare_state,
|
||||
snapshot_finalize_state=owner._snapshot_engine_finalize_state,
|
||||
snapshot_dispatch_state=owner._snapshot_engine_dispatch_state,
|
||||
|
||||
@ -360,33 +360,6 @@ class EngineApiDelegates:
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
return await self.api_facade.run_direct_tts_async(req)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
return self.api_facade.run_direct_tts(req)
|
||||
|
||||
def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
return self.api_facade.build_scheduler_request_specs(request_items)
|
||||
|
||||
def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
return self.api_facade.build_scheduler_submit_spec(payload)
|
||||
|
||||
@staticmethod
|
||||
def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
return EngineApiFacade.summarize_scheduler_states(states)
|
||||
|
||||
@staticmethod
|
||||
def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
return EngineApiFacade.summarize_scheduler_finished(items)
|
||||
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
return await self.api_facade.run_scheduler_submit(payload)
|
||||
|
||||
|
||||
class EngineRuntimeDelegates:
|
||||
@staticmethod
|
||||
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
|
||||
@ -424,23 +397,5 @@ class EngineRuntimeDelegates:
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
|
||||
|
||||
def get_scheduler_state(self) -> dict:
|
||||
return self.runtime_facade.get_scheduler_state()
|
||||
|
||||
def get_runtime_state(self) -> dict:
|
||||
return self.runtime_facade.get_runtime_state()
|
||||
|
||||
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
|
||||
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)
|
||||
|
||||
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
|
||||
return self.runtime_facade.set_refer_audio(refer_audio_path)
|
||||
|
||||
def set_gpt_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_gpt_weights(weights_path)
|
||||
|
||||
def set_sovits_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_sovits_weights(weights_path)
|
||||
|
||||
def handle_control(self, command: str) -> None:
|
||||
self.runtime_facade.handle_control(command)
|
||||
|
||||
53
GPT_SoVITS/TTS_infer_pack/unified_engine_public.py
Normal file
53
GPT_SoVITS/TTS_infer_pack/unified_engine_public.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution
|
||||
|
||||
|
||||
class EnginePublicInterface:
|
||||
PUBLIC_API_METHODS = (
|
||||
"run_direct_tts_async",
|
||||
"run_scheduler_submit",
|
||||
"run_scheduler_debug",
|
||||
"get_runtime_state",
|
||||
"set_refer_audio",
|
||||
"set_gpt_weights",
|
||||
"set_sovits_weights",
|
||||
"handle_control",
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
return await self.api_facade.run_direct_tts_async(req)
|
||||
|
||||
async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
return await self.api_facade.run_scheduler_submit(payload)
|
||||
|
||||
def get_runtime_state(self) -> dict:
|
||||
return self.runtime_facade.get_runtime_state()
|
||||
|
||||
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
|
||||
return self.runtime_facade.set_refer_audio(refer_audio_path)
|
||||
|
||||
def set_gpt_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_gpt_weights(weights_path)
|
||||
|
||||
def set_sovits_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_sovits_weights(weights_path)
|
||||
|
||||
def handle_control(self, command: str) -> None:
|
||||
self.runtime_facade.handle_control(command)
|
||||
|
||||
|
||||
class EngineCompatInterface:
|
||||
COMPAT_API_METHODS = (
|
||||
"run_direct_tts",
|
||||
"get_scheduler_state",
|
||||
)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
return self.api_facade.run_direct_tts(req)
|
||||
|
||||
def get_scheduler_state(self) -> dict:
|
||||
return self.runtime_facade.get_scheduler_state()
|
||||
Loading…
x
Reference in New Issue
Block a user