GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py
baicai-1145 a3a5aad157 Add unified engine components for TTS processing and state management
Introduce new modules including unified_engine_component_models, unified_engine_component_policy, unified_engine_component_registry, unified_engine_component_runtime, unified_engine_worker_completion, and unified_engine_worker_decode. These additions enhance the TTS framework by providing structured models for request handling, engine policies, and worker execution, significantly improving the architecture and maintainability of the system. The new components support asynchronous operations and optimize overall performance through better state management and processing capabilities.
2026-03-11 20:49:41 +08:00

336 lines
18 KiB
Python

from __future__ import annotations
import asyncio
import threading
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus
@dataclass
class EnginePolicyConfig:
    """Admission-policy settings consumed by EnginePolicyArbiterController.

    Each ``*_soft_max`` field is a soft ceiling for one engine metric. A new
    request is held back while the metric is greater than or equal to the
    ceiling. A value of 0 (or any non-positive value) disables that check.
    """

    # Master switch: when False no soft limit is evaluated and admission is immediate.
    enabled: bool = True
    # Delay between admission re-checks, in milliseconds.
    poll_wait_ms: float = 5.0
    # Ceiling for the "decode_backlog" metric (pending plus active decode work).
    decode_backlog_soft_max: int = 0
    # Ceiling for worker finalize_pending plus the finalize dispatcher's waiting count.
    finalize_pending_soft_max: int = 0
    # Ceiling for the worker's prepare_inflight counter.
    prepare_inflight_soft_max: int = 0
    # Ceiling for the number of requests in ACTIVE_DECODE status.
    active_decode_soft_max: int = 0
    # Ceiling for the number of requests in READY_FOR_PREFILL status.
    ready_for_prefill_soft_max: int = 0
    # Ceiling for the total number of active requests.
    active_request_soft_max: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the settings as a plain dict with bool/float/int-coerced values.

        Keys appear in field-declaration order.
        """
        payload: Dict[str, Any] = {
            "enabled": bool(self.enabled),
            "poll_wait_ms": float(self.poll_wait_ms),
        }
        for limit_name in (
            "decode_backlog_soft_max",
            "finalize_pending_soft_max",
            "prepare_inflight_soft_max",
            "active_decode_soft_max",
            "ready_for_prefill_soft_max",
            "active_request_soft_max",
        ):
            payload[limit_name] = int(getattr(self, limit_name))
        return payload
@dataclass
class EngineArbiterConfig:
    """Tuning knobs for the stage arbiter in EnginePolicyArbiterController."""

    # Upper bound, in milliseconds, for one arbiter wait() between ticks.
    poll_wait_ms: float = 5.0
    # Decode-runtime ticks granted per burst; prepare/finalize ticks refill the budget.
    decode_burst: int = 4
    # Age threshold (ms) compared against the prepare queue (and the decode-runtime pending queue).
    prepare_aging_ms: float = 10.0
    # Age threshold (ms) compared against the finalize queue.
    finalize_aging_ms: float = 10.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the arbiter settings as a plain dict, coercing each value to float or int."""
        casts = (
            ("poll_wait_ms", float),
            ("decode_burst", int),
            ("prepare_aging_ms", float),
            ("finalize_aging_ms", float),
        )
        return {name: cast(getattr(self, name)) for name, cast in casts}
@dataclass
class EngineArbiterState:
    """Mutable counters and last-decision record for EnginePolicyArbiterController.

    The controller mutates this object under its ``condition`` lock in
    ``mark_tick`` and reads it back through ``snapshot_state``.
    """

    # Number of mark_tick() calls, regardless of stage.
    total_ticks: int = 0
    # Ticks recorded with stage == "idle".
    total_idle_ticks: int = 0
    # Ticks recorded with stage == "prepare".
    total_prepare_dispatches: int = 0
    # Ticks recorded with stage == "decode_dispatch".
    total_decode_dispatches: int = 0
    # Ticks recorded with stage == "decode_runtime".
    total_decode_runtime_ticks: int = 0
    # Ticks recorded with stage == "finalize".
    total_finalize_dispatches: int = 0
    # Decode-runtime ticks left in the current burst. mark_tick refills it to
    # decode_burst on prepare/finalize ticks and decrements it (floored at 0)
    # on decode_runtime ticks. The dataclass default is 0; the controller
    # initialises it to decode_burst.
    decode_budget_remaining: int = 0
    # Stage name passed to the most recent mark_tick().
    last_stage: str = "idle"
    # Reason string passed to the most recent mark_tick().
    last_reason: str = "init"
    # time.perf_counter() value at the most recent mark_tick() (or at controller construction).
    last_observed_at: float = 0.0
    # policy_allowed flag passed to the most recent mark_tick().
    last_policy_allowed: bool = True
class EnginePolicyArbiterController:
    """Admission policy and stage arbiter for the unified engine.

    Two responsibilities live here:

    * Admission: ``wait_for_policy_admission`` polls engine state and
      asynchronously holds a new request back until none of the soft limits in
      ``EnginePolicyConfig`` is reached, or until a timeout expires.
    * Arbitration: ``select_stage`` chooses which stage should run next
      (``decode_runtime``, ``decode_dispatch``, ``finalize``, ``prepare`` or
      ``idle``). ``mark_tick`` records the choice and maintains the decode
      burst budget held in ``EngineArbiterState``.

    All engine state is read through the callables injected at construction.
    ``self.condition`` guards ``self.state`` and also serves as the wake-up
    signal for ``notify``/``wait``.
    """

    def __init__(
        self,
        *,
        policy_config: EnginePolicyConfig,
        arbiter_config: EngineArbiterConfig,
        snapshot_request_registry: Callable[[], Dict[str, Any]],
        get_worker_state: Callable[[], Dict[str, Any]],
        snapshot_prepare_state: Callable[[], Dict[str, Any]],
        snapshot_finalize_state: Callable[[], Dict[str, Any]],
        snapshot_dispatch_state: Callable[[], Dict[str, Any]],
        snapshot_decode_runtime_state: Callable[[], Dict[str, Any]],
        snapshot_job_registry: Callable[[], Dict[str, Any]],
        peek_queue_age_ms: Callable[[str], float],
        merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
    ) -> None:
        """Store configuration and the state-access callbacks.

        Args:
            policy_config: Admission soft limits and the admission poll interval.
            arbiter_config: Arbiter poll interval, decode burst size and aging thresholds.
            snapshot_request_registry: Returns the request registry. Its
                ``"active_requests"`` entry is read as a list of dicts with a
                ``"status"`` key.
            get_worker_state: Returns worker counters. The keys read are
                ``pending_jobs``, ``running_requests``, ``prepare_inflight``,
                ``finalize_pending``, ``finalize_inflight``,
                ``engine_decode_control_enabled`` and ``decode_runtime_has_work``.
            snapshot_prepare_state: Returns GPU-prepare dispatcher state (``waiting_count``).
            snapshot_finalize_state: Returns finalize dispatcher state (``waiting_count``).
            snapshot_dispatch_state: Returns decode dispatcher state (``waiting_count``).
            snapshot_decode_runtime_state: Returns engine decode runtime state
                (``pending_jobs``, ``active_request_count``, ``has_work``).
            snapshot_job_registry: Returns job registry state (``job_count``).
            peek_queue_age_ms: Called with ``"prepare"``, ``"finalize"`` or
                ``"decode_runtime_pending"``. Presumably returns the age of the
                oldest queued item in ms; confirm against the implementation.
            merge_request_state_profile: Receives ``(request_id, fields)`` to
                attach profiling data to a request.
        """
        self.policy_config = policy_config
        # Clamp poll intervals to at least 1 ms so a 0/negative config cannot busy-spin.
        self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0)
        self.arbiter_config = arbiter_config
        self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0)
        self.condition = threading.Condition()
        self.state = EngineArbiterState(
            decode_budget_remaining=int(self.arbiter_config.decode_burst),
            last_observed_at=time.perf_counter(),
        )
        self.snapshot_request_registry = snapshot_request_registry
        self.get_worker_state = get_worker_state
        self.snapshot_prepare_state = snapshot_prepare_state
        self.snapshot_finalize_state = snapshot_finalize_state
        self.snapshot_dispatch_state = snapshot_dispatch_state
        self.snapshot_decode_runtime_state = snapshot_decode_runtime_state
        self.snapshot_job_registry = snapshot_job_registry
        self.peek_queue_age_ms = peek_queue_age_ms
        self.merge_request_state_profile = merge_request_state_profile

    def snapshot_state(self) -> Dict[str, Any]:
        """Return a consistent copy of the arbiter config and counters, taken under the lock."""
        with self.condition:
            return {
                "config": self.arbiter_config.to_dict(),
                "total_ticks": int(self.state.total_ticks),
                "total_idle_ticks": int(self.state.total_idle_ticks),
                "total_prepare_dispatches": int(self.state.total_prepare_dispatches),
                "total_decode_dispatches": int(self.state.total_decode_dispatches),
                "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks),
                "total_finalize_dispatches": int(self.state.total_finalize_dispatches),
                "decode_budget_remaining": int(self.state.decode_budget_remaining),
                "last_stage": str(self.state.last_stage),
                "last_reason": str(self.state.last_reason),
                "last_policy_allowed": bool(self.state.last_policy_allowed),
                "last_observed_at": float(self.state.last_observed_at),
            }

    def notify(self) -> None:
        """Wake every thread currently blocked in ``wait``."""
        with self.condition:
            self.condition.notify_all()

    def wait(self) -> None:
        """Block the calling thread until ``notify`` is called or ``arbiter_poll_s`` elapses."""
        with self.condition:
            self.condition.wait(timeout=self.arbiter_poll_s)

    def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
        """Record one arbiter decision and update the decode burst budget.

        ``prepare`` and ``finalize`` ticks refill the budget to ``decode_burst``.
        ``decode_runtime`` ticks consume one unit, floored at zero.
        ``decode_dispatch`` and ``idle`` ticks leave the budget untouched.
        """
        with self.condition:
            self.state.total_ticks += 1
            if stage == "idle":
                self.state.total_idle_ticks += 1
            elif stage == "prepare":
                self.state.total_prepare_dispatches += 1
                self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
            elif stage == "finalize":
                self.state.total_finalize_dispatches += 1
                self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
            elif stage == "decode_dispatch":
                self.state.total_decode_dispatches += 1
            elif stage == "decode_runtime":
                self.state.total_decode_runtime_ticks += 1
                self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1)
            self.state.last_stage = str(stage)
            self.state.last_reason = str(reason)
            self.state.last_policy_allowed = bool(policy_allowed)
            self.state.last_observed_at = time.perf_counter()

    def build_stage_counters(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Aggregate per-stage counters from the registry, worker and dispatcher snapshots.

        Args:
            request_registry: Output of ``snapshot_request_registry``.
            worker_state: Output of ``get_worker_state``.

        Returns:
            A flat dict of integer/bool counters plus ``status_counts``, a
            status-name -> count map over ``active_requests``. ``decode_backlog``
            sums pending and active work from the engine decode runtime when
            engine decode control is enabled, and from the worker otherwise.
        """
        prepare_dispatcher_state = self.snapshot_prepare_state()
        finalize_dispatcher_state = self.snapshot_finalize_state()
        dispatcher_state = self.snapshot_dispatch_state()
        active_requests = list(request_registry.get("active_requests", []))
        status_counts: Dict[str, int] = {}
        for item in active_requests:
            status = str(item.get("status", "UNKNOWN"))
            status_counts[status] = status_counts.get(status, 0) + 1
        worker_pending_jobs = int(worker_state.get("pending_jobs", 0))
        worker_decode_active_size = int(worker_state.get("running_requests", 0))
        worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0))
        worker_finalize_pending = int(worker_state.get("finalize_pending", 0))
        worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0))
        decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
        engine_decode_runtime_state = self.snapshot_decode_runtime_state()
        engine_job_registry = self.snapshot_job_registry()
        decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0))
        decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0))
        # Backlog is measured where decode is actually driven: the engine
        # runtime when it controls decode, otherwise the worker.
        if decode_control_enabled:
            decode_backlog = decode_runtime_pending_jobs + decode_runtime_active_size
        else:
            decode_backlog = worker_pending_jobs + worker_decode_active_size
        return {
            "active_request_count": int(len(active_requests)),
            "status_counts": status_counts,
            "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)),
            "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)),
            "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)),
            "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)),
            "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)),
            "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)),
            "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)),
            "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)),
            "worker_pending_jobs": worker_pending_jobs,
            "worker_decode_active_size": worker_decode_active_size,
            "worker_decode_control_enabled": decode_control_enabled,
            "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)),
            "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs,
            "engine_decode_runtime_active_request_count": decode_runtime_active_size,
            "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)),
            "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)),
            "worker_prepare_inflight": worker_prepare_inflight,
            "worker_finalize_pending": worker_finalize_pending,
            "worker_finalize_inflight": worker_finalize_inflight,
            "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)),
            "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)),
            "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)),
            "decode_backlog": int(decode_backlog),
        }

    def build_policy_snapshot(
        self,
        request_registry: Dict[str, Any],
        worker_state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Evaluate the admission soft limits against the current counters.

        A limit of 0 or less disables its check. A metric blocks admission
        when its value is >= the limit.

        Returns:
            A dict with ``enabled``, ``allowed`` (True when the policy is
            disabled or nothing is blocked), ``blocked_reasons`` (a list of
            ``{"metric", "value", "limit"}``), the policy ``config``, a subset
            of ``metrics``, and ``observed_at`` (``time.perf_counter()``).
        """
        counters = self.build_stage_counters(request_registry, worker_state)
        config = self.policy_config.to_dict()
        blocked_reasons: List[Dict[str, Any]] = []
        # "Finalize pending" spans both the worker's pending count and the engine finalize queue.
        finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0))
        limit_checks = [
            ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])),
            ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])),
            ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])),
            ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])),
            ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])),
            ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])),
        ]
        if bool(config["enabled"]):
            for name, value, limit in limit_checks:
                if limit > 0 and int(value) >= int(limit):
                    blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)})
        return {
            "enabled": bool(config["enabled"]),
            "allowed": (not bool(config["enabled"])) or not blocked_reasons,
            "blocked_reasons": blocked_reasons,
            "config": config,
            "metrics": {
                "active_request_count": int(counters["active_request_count"]),
                "queued_request_count": int(counters["queued_request_count"]),
                "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]),
                "active_decode_request_count": int(counters["active_decode_request_count"]),
                "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]),
                "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]),
                "decode_backlog": int(counters["decode_backlog"]),
                "prepare_inflight": int(counters["worker_prepare_inflight"]),
                "finalize_pending": int(finalize_pending_total),
                "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)),
                "finalize_inflight": int(counters["worker_finalize_inflight"]),
            },
            "observed_at": time.perf_counter(),
        }

    async def wait_for_policy_admission(
        self,
        *,
        request_id: str | None,
        timeout_sec: float | None,
    ) -> tuple[float, Dict[str, Any]]:
        """Asynchronously wait until the admission policy allows a new request.

        Args:
            request_id: When non-empty, the wait time and the admitting snapshot
                are merged into this request's profile on admission.
            timeout_sec: Maximum wait in seconds. ``None`` or ``""`` waits
                indefinitely; negative values are treated as 0.

        Returns:
            ``(wait_ms, snapshot)``. When the policy is disabled this returns
            ``(0.0, snapshot)`` immediately and records nothing on the request.

        Raises:
            TimeoutError: The deadline passed while admission was still blocked.
        """
        if not self.policy_config.enabled:
            # Fast path: admit immediately, still returning a snapshot for observability.
            snapshot = self.build_policy_snapshot(self.snapshot_request_registry(), self.get_worker_state())
            return 0.0, snapshot
        # The enabled path builds its first snapshot inside the loop. The
        # previous implementation also built one here and discarded it, which
        # doubled the cost of the first admission check.
        start = time.perf_counter()
        deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
        while True:
            request_registry = self.snapshot_request_registry()
            worker_state = self.get_worker_state()
            snapshot = self.build_policy_snapshot(request_registry, worker_state)
            if snapshot["allowed"]:
                wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0)
                if request_id not in [None, ""]:
                    self.merge_request_state_profile(
                        str(request_id),
                        {
                            "engine_policy_wait_ms": float(wait_ms),
                            "engine_policy_snapshot": snapshot,
                        },
                    )
                return wait_ms, snapshot
            now = time.perf_counter()
            if deadline is not None and now >= deadline:
                blocked_summary = ", ".join(
                    f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", [])
                )
                raise TimeoutError(f"engine policy admission timeout ({blocked_summary})")
            await asyncio.sleep(self.policy_poll_s)

    def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
        """Choose the next stage for the arbiter to run.

        This method only decides. Callers are responsible for running the stage
        and for calling ``mark_tick``.

        Returns:
            ``(stage, reason, policy_snapshot, worker_state)``, where ``stage``
            is one of ``decode_runtime``, ``decode_dispatch``, ``finalize``,
            ``prepare`` or ``idle``.
        """
        request_registry = self.snapshot_request_registry()
        worker_state = self.get_worker_state()
        policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
        prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0))
        finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
        decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
        decode_runtime_state = self.snapshot_decode_runtime_state()
        # These values come from the engine decode runtime snapshot, not from worker_state.
        runtime_has_work = bool(decode_runtime_state.get("has_work", False))
        decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
        runtime_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
        runtime_active_requests = int(decode_runtime_state.get("active_request_count", 0))
        prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
        finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
        runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
        # Read the single field directly instead of building the whole snapshot_state() dict.
        with self.condition:
            decode_budget_remaining = int(self.state.decode_budget_remaining)
        policy_allowed = bool(policy_snapshot.get("allowed", True))

        # 1. Keep the engine decode runtime busy while its burst budget lasts.
        if (
            decode_control_enabled
            and runtime_has_work
            and policy_allowed
            and decode_budget_remaining > 0
            and (runtime_active_requests > 0 or runtime_pending_jobs > 0)
        ):
            return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state
        # 2. Runtime jobs have been pending long enough: run decode even with no budget left.
        #    NOTE(review): this reuses prepare_aging_ms as the threshold; confirm
        #    that a dedicated decode aging setting is not intended.
        if (
            decode_control_enabled
            and runtime_pending_jobs > 0
            and policy_allowed
            and runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms)
        ):
            return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state
        # 3. Hand prepared states to decode unless the runtime already has pending jobs queued.
        if (
            decode_waiting > 0
            and policy_allowed
            and (not decode_control_enabled or not runtime_has_work or runtime_pending_jobs <= 0)
        ):
            return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
        # 4. Decode is idle, blocked, or out of budget: let finalize run.
        if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
            return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
        # NOTE(review): given branches 1-4, the branches returning "finalize_aging",
        # "prepare_aging", "decode_priority_fallback", "finalize_fallback" and
        # "prepare_fallback" below cannot be reached.
        # - Reaching them with finalize/prepare work requires decode_waiting > 0,
        #   policy_allowed and budget > 0.
        # - Skipping branch 3 under those conditions requires decode control
        #   enabled, runtime_has_work and pending jobs > 0.
        # - That combination is exactly branch 1, which has already returned.
        # As a result, finalize_aging_ms has no effect. Those branches are kept
        # until the intended priority order is decided.
        if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
            return "finalize", "finalize_aging", policy_snapshot, worker_state
        # 5. Same rule as branch 4, for prepare.
        if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
            return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
        if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
            return "prepare", "prepare_aging", policy_snapshot, worker_state
        # 6. Nothing else to do: continue decode runtime work even if the budget is exhausted.
        if decode_control_enabled and runtime_has_work and policy_allowed:
            return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
        if decode_waiting > 0 and policy_allowed:
            return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state
        if finalize_waiting > 0:
            return "finalize", "finalize_fallback", policy_snapshot, worker_state
        if prepare_waiting > 0:
            return "prepare", "prepare_fallback", policy_snapshot, worker_state
        return "idle", "no_pending_work", policy_snapshot, worker_state