mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-13 05:18:12 +08:00
Introduce multiple new modules including unified_engine_api, unified_engine_audio, unified_engine_bridge, unified_engine_builder, unified_engine_components, unified_engine_delegates, and unified_engine_runtime. These additions provide a comprehensive framework for managing TTS requests, audio packing, and engine state management, significantly improving the architecture and maintainability of the TTS system. The new structure supports asynchronous operations and enhances overall performance through better request handling and processing capabilities.
199 lines
8.6 KiB
Python
199 lines
8.6 KiB
Python
from __future__ import annotations

import logging
import os
import signal
import sys
from typing import Any, Dict, Optional
|
|
|
|
|
|
class EngineRuntimeFacade:
    """Runtime-management facade over a unified TTS engine ``owner``.

    The facade keeps no state of its own besides ``owner``. The properties
    below forward to the attribute of the same name on ``owner``. The public
    methods use those collaborators to:

    * build state snapshots;
    * reload model weights;
    * change the default reference audio;
    * handle process-level control commands.
    """

    def __init__(self, owner: Any) -> None:
        # Engine object providing the components (tts, registries, locks, ...)
        # and the ``_snapshot_*`` / ``_is_engine_drained`` helpers used below.
        self.owner = owner

    # -- forwarding properties: each returns ``self.owner.<same name>` --

    @property
    def tts(self):
        return self.owner.tts

    @property
    def reference_registry(self):
        return self.owner.reference_registry

    @property
    def model_registry(self):
        return self.owner.model_registry

    @property
    def scheduler_worker(self):
        return self.owner.scheduler_worker

    @property
    def engine_decode_runtime_owner(self):
        return self.owner.engine_decode_runtime_owner

    @property
    def engine_policy_arbiter(self):
        return self.owner.engine_policy_arbiter

    @property
    def management_lock(self):
        return self.owner.management_lock

    @property
    def direct_tts_lock(self):
        return self.owner.direct_tts_lock

    @property
    def control_callbacks(self):
        return self.owner.control_callbacks

    # -- snapshot helpers --

    @staticmethod
    def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
        """Return ``dict(component.snapshot())``, or ``None`` if unavailable.

        This is best-effort by design, so diagnostics never break a state
        query. Each of the following yields ``None``:

        * a missing component;
        * a component without ``snapshot``;
        * a ``snapshot()`` call (or dict conversion) that raises.

        A failure is logged at DEBUG level with its traceback rather than
        being silently discarded.
        """
        if component is None or not hasattr(component, "snapshot"):
            return None
        try:
            return dict(component.snapshot())
        except Exception:
            logging.getLogger(__name__).debug(
                "snapshot() failed for component %s; reporting None",
                type(component).__name__,
                exc_info=True,
            )
            return None

    def _build_stage_counters(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return the stage counters computed by ``engine_policy_arbiter``."""
        return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state)

    def _build_engine_policy_snapshot(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return the engine policy snapshot computed by ``engine_policy_arbiter``."""
        return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state)

    def _build_stage_summary(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble a per-stage summary for diagnostics.

        The stage counters are spread at the top level. Explicit keys follow
        them, so on a key collision the explicit key wins. The explicit keys
        cover:

        * the drained flag;
        * the admission limits from ``worker_state``, coerced to ``int`` with
          a default of 0;
        * the policy snapshot and the owner's engine snapshots;
        * copies of ``worker_state``'s ``active_batch`` / ``prepare_state``
          (``{}`` when missing or falsy);
        * best-effort snapshots of optional ``tts`` sub-workers (``None`` when
          unavailable).
        """
        counters = self._build_stage_counters(request_registry, worker_state)
        bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None))
        ref_semantic_worker_state = self._safe_component_snapshot(
            getattr(self.tts, "prepare_ref_semantic_batch_worker", None)
        )
        text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None))

        return {
            **counters,
            "engine_drained": bool(self.owner._is_engine_drained()),
            "admission_config": {
                "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)),
                "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)),
            },
            "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state),
            "engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(),
            "engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(),
            "engine_job_registry": self.owner._snapshot_engine_job_registry(),
            "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
            "engine_prepare_state": self.owner._snapshot_engine_prepare_state(),
            "engine_finalize_state": self.owner._snapshot_engine_finalize_state(),
            "engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(),
            "active_batch": dict(worker_state.get("active_batch") or {}),
            "prepare_state": dict(worker_state.get("prepare_state") or {}),
            "bert_batch_worker_state": bert_worker_state,
            "ref_semantic_worker_state": ref_semantic_worker_state,
            "text_preprocessor_state": text_preprocessor_state,
        }

    # -- public state queries --

    def get_scheduler_state(self) -> dict:
        """Return ``scheduler_worker.snapshot()``."""
        return self.scheduler_worker.snapshot()

    def get_runtime_state(self) -> dict:
        """Return a full runtime report for the engine.

        The report includes the default reference, the model registry state,
        the scheduler state, the engine snapshots and the request registry.
        ``stage_summary`` reuses the same ``request_registry`` and scheduler
        state, but takes its own engine snapshots. Those can therefore differ
        slightly from the top-level ones if the engine is busy.
        """
        model_state = self.model_registry.snapshot()
        default_ref = self.reference_registry.get_default()
        scheduler_state = self.get_scheduler_state()
        request_registry = self.owner._snapshot_request_registry()
        engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state)
        engine_arbiter_state = self.owner._snapshot_engine_arbiter_state()
        engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state()
        engine_job_registry = self.owner._snapshot_engine_job_registry()
        engine_prepare_state = self.owner._snapshot_engine_prepare_state()
        engine_finalize_state = self.owner._snapshot_engine_finalize_state()
        engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state()
        engine_drained = self.owner._is_engine_drained()
        return {
            "message": "success",
            "default_reference": {
                "ref_audio_path": default_ref.ref_audio_path,
                "updated_at": default_ref.updated_at,
            },
            "model_registry": {
                "generation": model_state.generation,
                "t2s_generation": model_state.t2s_generation,
                "vits_generation": model_state.vits_generation,
                "t2s_weights_path": model_state.t2s_weights_path,
                "vits_weights_path": model_state.vits_weights_path,
                "updated_at": model_state.updated_at,
            },
            "worker_state": scheduler_state,
            "engine_policy": engine_policy,
            "engine_arbiter_state": engine_arbiter_state,
            "engine_decode_runtime_state": engine_decode_runtime_state,
            "engine_job_registry": engine_job_registry,
            "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
            "engine_prepare_state": engine_prepare_state,
            "engine_finalize_state": engine_finalize_state,
            "engine_dispatcher_state": engine_dispatcher_state,
            "engine_drained": bool(engine_drained),
            "request_registry": request_registry,
            "stage_summary": self._build_stage_summary(request_registry, scheduler_state),
        }

    # -- management operations --

    def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
        """Wait for ``scheduler_worker`` to become idle before a weight reload.

        Raises:
            TimeoutError: if ``wait_until_idle`` reports failure within
                ``timeout_sec`` seconds.
        """
        if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec):
            raise TimeoutError("scheduler worker did not drain before model reload")

    def set_refer_audio(self, refer_audio_path: str | None) -> dict:
        """Set or clear the default reference audio.

        ``None`` or ``""`` clears the default via ``reference_registry.clear()``.
        The ``tts`` object is not touched in that case. Any other value must
        satisfy ``os.path.exists``. It is then loaded into ``tts`` under
        ``direct_tts_lock`` and recorded in the registry, with the whole step
        done under ``management_lock``.

        Raises:
            FileNotFoundError: if ``refer_audio_path`` does not exist.
        """
        if refer_audio_path in [None, ""]:
            state = self.reference_registry.clear()
            return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
        if not os.path.exists(str(refer_audio_path)):
            raise FileNotFoundError(f"{refer_audio_path} not exists")
        with self.management_lock:
            with self.direct_tts_lock:
                self.tts.set_ref_audio(str(refer_audio_path))
            state = self.reference_registry.set_default(str(refer_audio_path))
        return {"message": "success", "default_ref_audio_path": state.ref_audio_path}

    def set_gpt_weights(self, weights_path: str) -> dict:
        """Reload the GPT (T2S) weights and bump the model registry generation.

        The scheduler is drained first. The weights are loaded and the runtime
        components refreshed under ``direct_tts_lock``. The whole sequence
        runs under ``management_lock``.

        Raises:
            ValueError: if ``weights_path`` is ``None`` or ``""``.
            TimeoutError: if the scheduler does not drain in time.
        """
        if weights_path in ["", None]:
            raise ValueError("gpt weight path is required")
        with self.management_lock:
            self._wait_for_safe_reload()
            with self.direct_tts_lock:
                self.tts.init_t2s_weights(weights_path)
                self.tts.refresh_runtime_components()
            state = self.model_registry.mark_t2s_reload(str(weights_path))
        return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation}

    def set_sovits_weights(self, weights_path: str) -> dict:
        """Reload the SoVITS (VITS) weights and bump the model registry generation.

        The scheduler is drained first. The weights are loaded and the runtime
        components refreshed under ``direct_tts_lock``. The whole sequence
        runs under ``management_lock``.

        Raises:
            ValueError: if ``weights_path`` is ``None`` or ``""``.
            TimeoutError: if the scheduler does not drain in time.
        """
        if weights_path in ["", None]:
            raise ValueError("sovits weight path is required")
        with self.management_lock:
            self._wait_for_safe_reload()
            with self.direct_tts_lock:
                self.tts.init_vits_weights(weights_path)
                self.tts.refresh_runtime_components()
            state = self.model_registry.mark_vits_reload(str(weights_path))
        return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation}

    # -- process control --

    @staticmethod
    def _flush_std_streams() -> None:
        """Flush ``sys.stdout`` and ``sys.stderr`` on a best-effort basis.

        This is used before ``os.execl``, which replaces the process image
        without flushing Python-level buffers.
        """
        for stream in (sys.stdout, sys.stderr):
            # The standard streams can be None, e.g. in GUI apps without a console.
            if stream is None:
                continue
            try:
                stream.flush()
            except (OSError, ValueError):
                # A closed or broken stream must not prevent the restart.
                pass

    def handle_control(self, command: str) -> None:
        """Execute a process-level control ``command``.

        ``"restart"`` calls ``control_callbacks.restart()`` when one is set.
        Otherwise it flushes stdio and re-executes the current interpreter
        with the same ``sys.argv`` via ``os.execl``, which does not return on
        success.

        ``"exit"`` calls ``control_callbacks.exit()`` when one is set.
        Otherwise it sends SIGTERM to the current process.

        Raises:
            ValueError: for any other command.
        """
        if command == "restart":
            if self.control_callbacks.restart is None:
                # exec does not flush Python's buffered stdio; without this,
                # pending print/log output would be lost on restart.
                self._flush_std_streams()
                os.execl(sys.executable, sys.executable, *sys.argv)
            else:
                self.control_callbacks.restart()
            return
        if command == "exit":
            if self.control_callbacks.exit is None:
                os.kill(os.getpid(), signal.SIGTERM)
            else:
                self.control_callbacks.exit()
            return
        raise ValueError(f"unsupported command: {command}")
|