mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-15 22:58:13 +08:00
Introduce new modules including unified_engine_component_models, unified_engine_component_policy, unified_engine_component_registry, unified_engine_component_runtime, unified_engine_worker_completion, and unified_engine_worker_decode. These additions enhance the TTS framework by providing structured models for request handling, engine policies, and worker execution, significantly improving the architecture and maintainability of the system. The new components support asynchronous operations and optimize overall performance through better state management and processing capabilities.
336 lines
18 KiB
Python
336 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus
|
|
|
|
|
|
@dataclass
class EnginePolicyConfig:
    """Soft admission limits consulted by the engine policy controller.

    Every ``*_soft_max`` field is the threshold for one load metric. The
    controller in this module only enforces thresholds greater than zero, so
    the default of ``0`` leaves the corresponding check switched off.
    """

    # Master switch: when false the policy never blocks admission.
    enabled: bool = True
    # Interval, in milliseconds, between admission re-checks while waiting.
    poll_wait_ms: float = 5.0
    # Per-metric thresholds (0 = no limit for that metric).
    decode_backlog_soft_max: int = 0
    finalize_pending_soft_max: int = 0
    prepare_inflight_soft_max: int = 0
    active_decode_soft_max: int = 0
    ready_for_prefill_soft_max: int = 0
    active_request_soft_max: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the settings as a plain dict with values coerced to builtins.

        ``enabled`` is coerced with ``bool``, ``poll_wait_ms`` with ``float``
        and every soft limit with ``int``; keys follow field declaration order.
        """
        payload: Dict[str, Any] = {
            "enabled": bool(self.enabled),
            "poll_wait_ms": float(self.poll_wait_ms),
        }
        for limit_name in (
            "decode_backlog_soft_max",
            "finalize_pending_soft_max",
            "prepare_inflight_soft_max",
            "active_decode_soft_max",
            "ready_for_prefill_soft_max",
            "active_request_soft_max",
        ):
            payload[limit_name] = int(getattr(self, limit_name))
        return payload
|
|
|
|
|
|
@dataclass
class EngineArbiterConfig:
    """Tuning knobs for the stage arbiter implemented in this module."""

    # Timeout, in milliseconds, of the arbiter's condition-variable wait.
    poll_wait_ms: float = 5.0
    # Decode budget granted at start-up and after each prepare/finalize tick.
    decode_burst: int = 4
    # Queue ages (ms) at which the arbiter's aging rules for prepare and
    # finalize work start to apply.
    prepare_aging_ms: float = 10.0
    finalize_aging_ms: float = 10.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the settings as a plain dict with values coerced to builtins."""
        poll_wait_ms = float(self.poll_wait_ms)
        decode_burst = int(self.decode_burst)
        prepare_aging_ms = float(self.prepare_aging_ms)
        finalize_aging_ms = float(self.finalize_aging_ms)
        return dict(
            poll_wait_ms=poll_wait_ms,
            decode_burst=decode_burst,
            prepare_aging_ms=prepare_aging_ms,
            finalize_aging_ms=finalize_aging_ms,
        )
|
|
|
|
|
|
@dataclass
class EngineArbiterState:
    """Mutable bookkeeping kept by the stage arbiter.

    The ``total_*`` fields are running counters, ``decode_budget_remaining``
    is the current decode budget, and the ``last_*`` fields describe the most
    recently recorded tick.
    """

    # Running totals, one per kind of tick.
    total_ticks: int = 0
    total_idle_ticks: int = 0
    total_prepare_dispatches: int = 0
    total_decode_dispatches: int = 0
    total_decode_runtime_ticks: int = 0
    total_finalize_dispatches: int = 0

    # Decode-runtime ticks still allowed before the budget is exhausted.
    decode_budget_remaining: int = 0

    # Details of the latest tick.
    last_stage: str = "idle"
    last_reason: str = "init"
    last_observed_at: float = 0.0
    last_policy_allowed: bool = True
|
|
|
|
|
|
class EnginePolicyArbiterController:
    """Admission policy gate and stage arbiter for the engine.

    The controller owns no queues itself. Everything it knows about the
    engine comes from the snapshot callbacks injected through ``__init__``;
    from those it derives load counters, decides whether new work may be
    admitted (``build_policy_snapshot`` / ``wait_for_policy_admission``) and
    picks which stage should run next (``select_stage``). Arbiter bookkeeping
    lives in ``self.state`` and is guarded by ``self.condition``.
    """

    def __init__(
        self,
        *,
        policy_config: EnginePolicyConfig,
        arbiter_config: EngineArbiterConfig,
        snapshot_request_registry: Callable[[], Dict[str, Any]],
        get_worker_state: Callable[[], Dict[str, Any]],
        snapshot_prepare_state: Callable[[], Dict[str, Any]],
        snapshot_finalize_state: Callable[[], Dict[str, Any]],
        snapshot_dispatch_state: Callable[[], Dict[str, Any]],
        snapshot_decode_runtime_state: Callable[[], Dict[str, Any]],
        snapshot_job_registry: Callable[[], Dict[str, Any]],
        peek_queue_age_ms: Callable[[str], float],
        merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
    ) -> None:
        """Store the configuration and the engine callbacks.

        Both poll intervals are converted from milliseconds to seconds and
        floored at one millisecond. The arbiter state starts with a full
        decode budget (``arbiter_config.decode_burst``).
        """
        self.policy_config = policy_config
        self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0)
        self.arbiter_config = arbiter_config
        self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0)
        self.condition = threading.Condition()
        self.state = EngineArbiterState(
            decode_budget_remaining=int(self.arbiter_config.decode_burst),
            last_observed_at=time.perf_counter(),
        )
        self.snapshot_request_registry = snapshot_request_registry
        self.get_worker_state = get_worker_state
        self.snapshot_prepare_state = snapshot_prepare_state
        self.snapshot_finalize_state = snapshot_finalize_state
        self.snapshot_dispatch_state = snapshot_dispatch_state
        self.snapshot_decode_runtime_state = snapshot_decode_runtime_state
        self.snapshot_job_registry = snapshot_job_registry
        self.peek_queue_age_ms = peek_queue_age_ms
        self.merge_request_state_profile = merge_request_state_profile

    def snapshot_state(self) -> Dict[str, Any]:
        """Return a copy of the arbiter counters plus the arbiter config.

        The copy is taken while holding ``self.condition``.
        """
        with self.condition:
            state = self.state
            return {
                "config": self.arbiter_config.to_dict(),
                "total_ticks": int(state.total_ticks),
                "total_idle_ticks": int(state.total_idle_ticks),
                "total_prepare_dispatches": int(state.total_prepare_dispatches),
                "total_decode_dispatches": int(state.total_decode_dispatches),
                "total_decode_runtime_ticks": int(state.total_decode_runtime_ticks),
                "total_finalize_dispatches": int(state.total_finalize_dispatches),
                "decode_budget_remaining": int(state.decode_budget_remaining),
                "last_stage": str(state.last_stage),
                "last_reason": str(state.last_reason),
                "last_policy_allowed": bool(state.last_policy_allowed),
                "last_observed_at": float(state.last_observed_at),
            }

    def notify(self) -> None:
        """Wake every thread currently blocked in ``wait``."""
        with self.condition:
            self.condition.notify_all()

    def wait(self) -> None:
        """Block until notified or until the arbiter poll interval elapses."""
        with self.condition:
            self.condition.wait(timeout=self.arbiter_poll_s)

    def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
        """Record one arbiter tick for ``stage`` and update the decode budget.

        ``prepare`` and ``finalize`` ticks refill the decode budget to
        ``decode_burst``; a ``decode_runtime`` tick consumes one unit (never
        going below zero). ``idle`` and ``decode_dispatch`` only bump their
        counters, and an unrecognised stage only bumps ``total_ticks``. The
        ``last_*`` fields are refreshed on every call.
        """
        with self.condition:
            state = self.state
            state.total_ticks += 1
            if stage == "idle":
                state.total_idle_ticks += 1
            elif stage == "prepare":
                state.total_prepare_dispatches += 1
                state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
            elif stage == "finalize":
                state.total_finalize_dispatches += 1
                state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
            elif stage == "decode_dispatch":
                state.total_decode_dispatches += 1
            elif stage == "decode_runtime":
                state.total_decode_runtime_ticks += 1
                state.decode_budget_remaining = max(0, int(state.decode_budget_remaining) - 1)
            state.last_stage = str(stage)
            state.last_reason = str(reason)
            state.last_policy_allowed = bool(policy_allowed)
            # perf_counter is a monotonic clock, not wall-clock time.
            state.last_observed_at = time.perf_counter()

    def build_stage_counters(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Collect per-stage load counters from the registry and engine snapshots.

        Args:
            request_registry: Registry snapshot; its ``"active_requests"``
                entry is read as a list of mappings carrying a ``"status"``.
            worker_state: Worker snapshot; missing keys default to 0 / False.

        Returns:
            A flat dict of counters. ``decode_backlog`` is computed from the
            engine decode-runtime snapshot when the worker reports
            ``engine_decode_control_enabled``, otherwise from the worker's own
            pending/running numbers.
        """
        prepare_dispatcher_state = self.snapshot_prepare_state()
        finalize_dispatcher_state = self.snapshot_finalize_state()
        dispatcher_state = self.snapshot_dispatch_state()

        active_requests = list(request_registry.get("active_requests", []))
        status_counts: Dict[str, int] = {}
        for item in active_requests:
            status = str(item.get("status", "UNKNOWN"))
            status_counts[status] = status_counts.get(status, 0) + 1

        worker_pending_jobs = int(worker_state.get("pending_jobs", 0))
        worker_decode_active_size = int(worker_state.get("running_requests", 0))
        worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0))
        worker_finalize_pending = int(worker_state.get("finalize_pending", 0))
        worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0))

        engine_decode_runtime_state = self.snapshot_decode_runtime_state()
        engine_job_registry = self.snapshot_job_registry()
        decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0))
        decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0))

        decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
        if decode_control_enabled:
            decode_backlog = decode_runtime_pending_jobs + decode_runtime_active_size
        else:
            decode_backlog = worker_pending_jobs + worker_decode_active_size

        def count_of(status: Any) -> int:
            return int(status_counts.get(status, 0))

        return {
            "active_request_count": int(len(active_requests)),
            "status_counts": status_counts,
            "queued_request_count": count_of(EngineStatus.QUEUED),
            "cpu_prepare_request_count": count_of(EngineStatus.CPU_PREPARING),
            "gpu_prepare_request_count": count_of(EngineStatus.GPU_PREPARING),
            "ready_for_prefill_request_count": count_of(EngineStatus.READY_FOR_PREFILL),
            "active_decode_request_count": count_of(EngineStatus.ACTIVE_DECODE),
            "ready_for_finalize_request_count": count_of(EngineStatus.READY_FOR_FINALIZE),
            "finalizing_request_count": count_of(EngineStatus.FINALIZING),
            "streaming_request_count": count_of(EngineStatus.STREAMING),
            "worker_pending_jobs": worker_pending_jobs,
            "worker_decode_active_size": worker_decode_active_size,
            "worker_decode_control_enabled": decode_control_enabled,
            "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)),
            "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs,
            "engine_decode_runtime_active_request_count": decode_runtime_active_size,
            "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)),
            "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)),
            "worker_prepare_inflight": worker_prepare_inflight,
            "worker_finalize_pending": worker_finalize_pending,
            "worker_finalize_inflight": worker_finalize_inflight,
            "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)),
            "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)),
            "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)),
            "decode_backlog": int(decode_backlog),
        }

    def build_policy_snapshot(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Evaluate the soft limits against the current stage counters.

        A metric blocks admission when its configured limit is greater than
        zero and the metric has reached or exceeded it. With the policy
        disabled no metric is checked and ``allowed`` is always true.

        Returns:
            A dict with ``enabled``, ``allowed``, ``blocked_reasons`` (one
            ``{"metric", "value", "limit"}`` entry per violated limit),
            ``config``, a ``metrics`` sub-dict and ``observed_at`` (a
            ``time.perf_counter()`` reading).
        """
        counters = self.build_stage_counters(request_registry, worker_state)
        config = self.policy_config.to_dict()
        policy_enabled = bool(config["enabled"])
        engine_finalize_queue_count = int(counters.get("engine_finalize_queue_count", 0))
        # Finalize pressure is the worker's pending count plus the engine-side queue.
        finalize_pending_total = int(counters["worker_finalize_pending"]) + engine_finalize_queue_count

        limit_checks = [
            ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])),
            ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])),
            ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])),
            ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])),
            (
                "ready_for_prefill_requests",
                counters["ready_for_prefill_request_count"],
                int(config["ready_for_prefill_soft_max"]),
            ),
            ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])),
        ]
        blocked_reasons: List[Dict[str, Any]] = []
        if policy_enabled:
            blocked_reasons = [
                {"metric": name, "value": int(value), "limit": int(limit)}
                for name, value, limit in limit_checks
                if limit > 0 and int(value) >= int(limit)
            ]

        return {
            "enabled": policy_enabled,
            "allowed": (not policy_enabled) or not blocked_reasons,
            "blocked_reasons": blocked_reasons,
            "config": config,
            "metrics": {
                "active_request_count": int(counters["active_request_count"]),
                "queued_request_count": int(counters["queued_request_count"]),
                "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]),
                "active_decode_request_count": int(counters["active_decode_request_count"]),
                "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]),
                "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]),
                "decode_backlog": int(counters["decode_backlog"]),
                "prepare_inflight": int(counters["worker_prepare_inflight"]),
                "finalize_pending": int(finalize_pending_total),
                "engine_finalize_queue_count": engine_finalize_queue_count,
                "finalize_inflight": int(counters["worker_finalize_inflight"]),
            },
            "observed_at": time.perf_counter(),
        }

    def _take_policy_snapshot(self) -> Dict[str, Any]:
        """Fetch fresh registry/worker snapshots and evaluate the policy on them."""
        request_registry = self.snapshot_request_registry()
        worker_state = self.get_worker_state()
        return self.build_policy_snapshot(request_registry, worker_state)

    async def wait_for_policy_admission(
        self,
        *,
        request_id: str | None,
        timeout_sec: float | None,
    ) -> tuple[float, Dict[str, Any]]:
        """Wait until the admission policy allows new work.

        Args:
            request_id: When not ``None``/empty, the wait time and the
                admitting snapshot are merged into that request's profile via
                ``merge_request_state_profile``.
            timeout_sec: Maximum time to wait in seconds; ``None`` or ``""``
                waits indefinitely. Negative values are treated as zero.

        Returns:
            ``(wait_ms, snapshot)``. With the policy disabled this returns
            ``(0.0, snapshot)`` immediately and records nothing.

        Raises:
            TimeoutError: The policy was still blocking at the deadline. The
                message lists every blocking metric as ``name=value/limit``.
        """
        snapshot = self._take_policy_snapshot()
        if not self.policy_config.enabled:
            return 0.0, snapshot

        start = time.perf_counter()
        deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
        while True:
            if snapshot["allowed"]:
                wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0)
                if request_id not in [None, ""]:
                    self.merge_request_state_profile(
                        str(request_id),
                        {
                            "engine_policy_wait_ms": float(wait_ms),
                            "engine_policy_snapshot": snapshot,
                        },
                    )
                return wait_ms, snapshot

            delay = self.policy_poll_s
            if deadline is not None:
                now = time.perf_counter()
                if now >= deadline:
                    blocked_summary = ", ".join(
                        f"{item['metric']}={item['value']}/{item['limit']}"
                        for item in snapshot.get("blocked_reasons", [])
                    )
                    raise TimeoutError(f"engine policy admission timeout ({blocked_summary})")
                # Never sleep past the deadline, otherwise the timeout could be
                # reported up to one full poll interval late.
                delay = min(delay, deadline - now)
            await asyncio.sleep(delay)
            snapshot = self._take_policy_snapshot()

    def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
        """Choose the stage the engine should run next.

        Rules are tried in order and the first match wins:

        1. decode runtime, while it has work, policy allows and budget remains;
        2. decode runtime, when its pending jobs have aged past the threshold;
        3. decode dispatch, when the decode runtime is not busy with pending jobs;
        4. finalize, then prepare, when decode is blocked, idle or out of
           budget, or when their queue has aged past its threshold;
        5. fallbacks in the order decode runtime, decode dispatch, finalize,
           prepare;
        6. otherwise ``"idle"``.

        Returns:
            ``(stage, reason, policy_snapshot, worker_state)``.
        """
        request_registry = self.snapshot_request_registry()
        worker_state = self.get_worker_state()
        policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)

        prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0))
        finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
        decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))

        decode_runtime_state = self.snapshot_decode_runtime_state()
        runtime_has_work = bool(decode_runtime_state.get("has_work", False))
        decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
        runtime_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
        runtime_active_requests = int(decode_runtime_state.get("active_request_count", 0))

        prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
        finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
        runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))

        with self.condition:
            decode_budget_remaining = int(self.state.decode_budget_remaining)
        policy_allowed = bool(policy_snapshot.get("allowed", True))

        prepare_aging_ms = float(self.arbiter_config.prepare_aging_ms)
        finalize_aging_ms = float(self.arbiter_config.finalize_aging_ms)
        # Lower-priority stages get a turn when decode dispatch has nothing
        # queued, is blocked by policy, or has used up its burst budget.
        decode_yields = decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0

        def choose(stage: str, reason: str) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
            return stage, reason, policy_snapshot, worker_state

        if (
            decode_control_enabled
            and runtime_has_work
            and policy_allowed
            and decode_budget_remaining > 0
            and (runtime_active_requests > 0 or runtime_pending_jobs > 0)
        ):
            return choose("decode_runtime", "worker_active_batch_progress")
        # NOTE(review): pending decode-runtime jobs are aged against
        # prepare_aging_ms; there is no dedicated decode aging setting.
        if (
            decode_control_enabled
            and runtime_pending_jobs > 0
            and policy_allowed
            and runtime_pending_age_ms >= prepare_aging_ms
        ):
            return choose("decode_runtime", "decode_runtime_pending_aging")
        if (
            decode_waiting > 0
            and policy_allowed
            and (not decode_control_enabled or not runtime_has_work or runtime_pending_jobs <= 0)
        ):
            return choose("decode_dispatch", "dispatch_prepared_state")
        if finalize_waiting > 0 and decode_yields:
            return choose("finalize", "decode_blocked_or_budget_exhausted")
        if finalize_waiting > 0 and finalize_age_ms >= finalize_aging_ms:
            return choose("finalize", "finalize_aging")
        if prepare_waiting > 0 and decode_yields:
            return choose("prepare", "decode_blocked_or_budget_exhausted")
        if prepare_waiting > 0 and prepare_age_ms >= prepare_aging_ms:
            return choose("prepare", "prepare_aging")
        if decode_control_enabled and runtime_has_work and policy_allowed:
            return choose("decode_runtime", "worker_active_batch_progress_fallback")
        if decode_waiting > 0 and policy_allowed:
            return choose("decode_dispatch", "decode_priority_fallback")
        if finalize_waiting > 0:
            return choose("finalize", "finalize_fallback")
        if prepare_waiting > 0:
            return choose("prepare", "prepare_fallback")
        return choose("idle", "no_pending_work")
|