Add unified engine components for TTS processing and state management

Introduce new modules including unified_engine_component_models, unified_engine_component_policy, unified_engine_component_registry, unified_engine_component_runtime, unified_engine_worker_completion, and unified_engine_worker_decode. These additions enhance the TTS framework by providing structured models for request handling, engine policies, and worker execution, significantly improving the architecture and maintainability of the system. The new components support asynchronous operations and optimize overall performance through better state management and processing capabilities.
This commit is contained in:
baicai-1145 2026-03-11 20:49:41 +08:00
parent 3fd4f48651
commit a3a5aad157
13 changed files with 2772 additions and 2605 deletions

View File

@ -0,0 +1,120 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec
@dataclass
class RuntimeControlCallbacks:
restart: Callable[[], None] | None = None
exit: Callable[[], None] | None = None
@dataclass
class DirectTTSExecution:
media_type: str
streaming: bool
audio_generator: Optional[Generator[bytes, None, None]] = None
audio_bytes: Optional[bytes] = None
request_id: Optional[str] = None
@dataclass
class NormalizedEngineRequest:
request_id: str
text: str
text_lang: str
ref_audio_path: str
prompt_lang: str
prompt_text: str = ""
aux_ref_audio_paths: List[str] | None = None
top_k: int = 15
top_p: float = 1.0
temperature: float = 1.0
repetition_penalty: float = 1.35
early_stop_num: int = -1
ready_step: int = 0
text_split_method: str = "cut5"
batch_size: int = 1
batch_threshold: float = 0.75
split_bucket: bool = False
speed_factor: float = 1.0
fragment_interval: float = 0.3
seed: int = -1
media_type: str = "wav"
streaming_mode: bool | int = False
return_fragment: bool = False
fixed_length_chunk: bool = False
response_streaming: bool = False
parallel_infer: bool = False
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
min_chunk_length: int = 16
timeout_sec: float | None = None
def to_payload(self) -> Dict[str, Any]:
return {
"request_id": self.request_id,
"text": self.text,
"text_lang": self.text_lang,
"ref_audio_path": self.ref_audio_path,
"aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None,
"prompt_text": self.prompt_text,
"prompt_lang": self.prompt_lang,
"top_k": self.top_k,
"top_p": self.top_p,
"temperature": self.temperature,
"text_split_method": self.text_split_method,
"batch_size": self.batch_size,
"batch_threshold": self.batch_threshold,
"speed_factor": self.speed_factor,
"split_bucket": self.split_bucket,
"fragment_interval": self.fragment_interval,
"seed": self.seed,
"media_type": self.media_type,
"streaming_mode": self.streaming_mode,
"return_fragment": self.return_fragment,
"fixed_length_chunk": self.fixed_length_chunk,
"response_streaming": self.response_streaming,
"parallel_infer": self.parallel_infer,
"repetition_penalty": self.repetition_penalty,
"sample_steps": self.sample_steps,
"super_sampling": self.super_sampling,
"overlap_length": self.overlap_length,
"min_chunk_length": self.min_chunk_length,
"early_stop_num": self.early_stop_num,
"ready_step": self.ready_step,
"timeout_sec": self.timeout_sec,
}
def to_scheduler_spec(self) -> SchedulerRequestSpec:
return SchedulerRequestSpec(
request_id=self.request_id,
ref_audio_path=Path(self.ref_audio_path),
prompt_text=self.prompt_text,
prompt_lang=self.prompt_lang,
text=self.text,
text_lang=self.text_lang,
top_k=self.top_k,
top_p=self.top_p,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
early_stop_num=self.early_stop_num,
ready_step=self.ready_step,
)
@dataclass
class SchedulerDebugExecution:
payload: Dict[str, Any]
@dataclass
class SchedulerSubmitExecution:
audio_bytes: bytes
media_type: str
headers: Dict[str, str]

View File

@ -0,0 +1,335 @@
from __future__ import annotations
import asyncio
import threading
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus
@dataclass
class EnginePolicyConfig:
enabled: bool = True
poll_wait_ms: float = 5.0
decode_backlog_soft_max: int = 0
finalize_pending_soft_max: int = 0
prepare_inflight_soft_max: int = 0
active_decode_soft_max: int = 0
ready_for_prefill_soft_max: int = 0
active_request_soft_max: int = 0
def to_dict(self) -> Dict[str, Any]:
return {
"enabled": bool(self.enabled),
"poll_wait_ms": float(self.poll_wait_ms),
"decode_backlog_soft_max": int(self.decode_backlog_soft_max),
"finalize_pending_soft_max": int(self.finalize_pending_soft_max),
"prepare_inflight_soft_max": int(self.prepare_inflight_soft_max),
"active_decode_soft_max": int(self.active_decode_soft_max),
"ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max),
"active_request_soft_max": int(self.active_request_soft_max),
}
@dataclass
class EngineArbiterConfig:
poll_wait_ms: float = 5.0
decode_burst: int = 4
prepare_aging_ms: float = 10.0
finalize_aging_ms: float = 10.0
def to_dict(self) -> Dict[str, Any]:
return {
"poll_wait_ms": float(self.poll_wait_ms),
"decode_burst": int(self.decode_burst),
"prepare_aging_ms": float(self.prepare_aging_ms),
"finalize_aging_ms": float(self.finalize_aging_ms),
}
@dataclass
class EngineArbiterState:
total_ticks: int = 0
total_idle_ticks: int = 0
total_prepare_dispatches: int = 0
total_decode_dispatches: int = 0
total_decode_runtime_ticks: int = 0
total_finalize_dispatches: int = 0
decode_budget_remaining: int = 0
last_stage: str = "idle"
last_reason: str = "init"
last_observed_at: float = 0.0
last_policy_allowed: bool = True
class EnginePolicyArbiterController:
def __init__(
self,
*,
policy_config: EnginePolicyConfig,
arbiter_config: EngineArbiterConfig,
snapshot_request_registry: Callable[[], Dict[str, Any]],
get_worker_state: Callable[[], Dict[str, Any]],
snapshot_prepare_state: Callable[[], Dict[str, Any]],
snapshot_finalize_state: Callable[[], Dict[str, Any]],
snapshot_dispatch_state: Callable[[], Dict[str, Any]],
snapshot_decode_runtime_state: Callable[[], Dict[str, Any]],
snapshot_job_registry: Callable[[], Dict[str, Any]],
peek_queue_age_ms: Callable[[str], float],
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
) -> None:
self.policy_config = policy_config
self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0)
self.arbiter_config = arbiter_config
self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0)
self.condition = threading.Condition()
self.state = EngineArbiterState(
decode_budget_remaining=int(self.arbiter_config.decode_burst),
last_observed_at=time.perf_counter(),
)
self.snapshot_request_registry = snapshot_request_registry
self.get_worker_state = get_worker_state
self.snapshot_prepare_state = snapshot_prepare_state
self.snapshot_finalize_state = snapshot_finalize_state
self.snapshot_dispatch_state = snapshot_dispatch_state
self.snapshot_decode_runtime_state = snapshot_decode_runtime_state
self.snapshot_job_registry = snapshot_job_registry
self.peek_queue_age_ms = peek_queue_age_ms
self.merge_request_state_profile = merge_request_state_profile
def snapshot_state(self) -> Dict[str, Any]:
with self.condition:
return {
"config": self.arbiter_config.to_dict(),
"total_ticks": int(self.state.total_ticks),
"total_idle_ticks": int(self.state.total_idle_ticks),
"total_prepare_dispatches": int(self.state.total_prepare_dispatches),
"total_decode_dispatches": int(self.state.total_decode_dispatches),
"total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks),
"total_finalize_dispatches": int(self.state.total_finalize_dispatches),
"decode_budget_remaining": int(self.state.decode_budget_remaining),
"last_stage": str(self.state.last_stage),
"last_reason": str(self.state.last_reason),
"last_policy_allowed": bool(self.state.last_policy_allowed),
"last_observed_at": float(self.state.last_observed_at),
}
def notify(self) -> None:
with self.condition:
self.condition.notify_all()
def wait(self) -> None:
with self.condition:
self.condition.wait(timeout=self.arbiter_poll_s)
def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
with self.condition:
self.state.total_ticks += 1
if stage == "idle":
self.state.total_idle_ticks += 1
elif stage == "prepare":
self.state.total_prepare_dispatches += 1
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
elif stage == "finalize":
self.state.total_finalize_dispatches += 1
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
elif stage == "decode_dispatch":
self.state.total_decode_dispatches += 1
elif stage == "decode_runtime":
self.state.total_decode_runtime_ticks += 1
self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1)
self.state.last_stage = str(stage)
self.state.last_reason = str(reason)
self.state.last_policy_allowed = bool(policy_allowed)
self.state.last_observed_at = time.perf_counter()
def build_stage_counters(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
prepare_dispatcher_state = self.snapshot_prepare_state()
finalize_dispatcher_state = self.snapshot_finalize_state()
dispatcher_state = self.snapshot_dispatch_state()
active_requests = list(request_registry.get("active_requests", []))
status_counts: Dict[str, int] = {}
for item in active_requests:
status = str(item.get("status", "UNKNOWN"))
status_counts[status] = status_counts.get(status, 0) + 1
worker_pending_jobs = int(worker_state.get("pending_jobs", 0))
worker_decode_active_size = int(worker_state.get("running_requests", 0))
worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0))
worker_finalize_pending = int(worker_state.get("finalize_pending", 0))
worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0))
engine_decode_runtime_state = self.snapshot_decode_runtime_state()
engine_job_registry = self.snapshot_job_registry()
decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0))
decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0))
return {
"active_request_count": int(len(active_requests)),
"status_counts": status_counts,
"queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)),
"cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)),
"gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)),
"ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)),
"active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)),
"ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)),
"finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)),
"streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)),
"worker_pending_jobs": worker_pending_jobs,
"worker_decode_active_size": worker_decode_active_size,
"worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)),
"worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)),
"engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs,
"engine_decode_runtime_active_request_count": decode_runtime_active_size,
"engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)),
"engine_job_registry_count": int(engine_job_registry.get("job_count", 0)),
"worker_prepare_inflight": worker_prepare_inflight,
"worker_finalize_pending": worker_finalize_pending,
"worker_finalize_inflight": worker_finalize_inflight,
"engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)),
"engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)),
"engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)),
"decode_backlog": int(
decode_runtime_pending_jobs + decode_runtime_active_size
if bool(worker_state.get("engine_decode_control_enabled", False))
else worker_pending_jobs + worker_decode_active_size
),
}
def build_policy_snapshot(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
counters = self.build_stage_counters(request_registry, worker_state)
config = self.policy_config.to_dict()
blocked_reasons: List[Dict[str, Any]] = []
finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0))
limit_checks = [
("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])),
("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])),
("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])),
("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])),
("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])),
("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])),
]
if bool(config["enabled"]):
for name, value, limit in limit_checks:
if limit > 0 and int(value) >= int(limit):
blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)})
return {
"enabled": bool(config["enabled"]),
"allowed": (not bool(config["enabled"])) or not blocked_reasons,
"blocked_reasons": blocked_reasons,
"config": config,
"metrics": {
"active_request_count": int(counters["active_request_count"]),
"queued_request_count": int(counters["queued_request_count"]),
"ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]),
"active_decode_request_count": int(counters["active_decode_request_count"]),
"engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]),
"engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]),
"decode_backlog": int(counters["decode_backlog"]),
"prepare_inflight": int(counters["worker_prepare_inflight"]),
"finalize_pending": int(finalize_pending_total),
"engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)),
"finalize_inflight": int(counters["worker_finalize_inflight"]),
},
"observed_at": time.perf_counter(),
}
async def wait_for_policy_admission(
self,
*,
request_id: str | None,
timeout_sec: float | None,
) -> tuple[float, Dict[str, Any]]:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
snapshot = self.build_policy_snapshot(request_registry, worker_state)
if not self.policy_config.enabled:
return 0.0, snapshot
start = time.perf_counter()
deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
while True:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
snapshot = self.build_policy_snapshot(request_registry, worker_state)
if snapshot["allowed"]:
wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0)
if request_id not in [None, ""]:
self.merge_request_state_profile(
str(request_id),
{
"engine_policy_wait_ms": float(wait_ms),
"engine_policy_snapshot": snapshot,
},
)
return wait_ms, snapshot
now = time.perf_counter()
if deadline is not None and now >= deadline:
blocked_summary = ", ".join(
f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", [])
)
raise TimeoutError(f"engine policy admission timeout ({blocked_summary})")
await asyncio.sleep(self.policy_poll_s)
def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0))
finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
decode_runtime_state = self.snapshot_decode_runtime_state()
worker_decode_has_work = bool(decode_runtime_state.get("has_work", False))
worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
policy_allowed = bool(policy_snapshot.get("allowed", True))
if (
worker_decode_control_enabled
and worker_decode_has_work
and policy_allowed
and decode_budget_remaining > 0
and (worker_running_requests > 0 or worker_pending_jobs > 0)
):
return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state
if (
worker_decode_control_enabled
and worker_pending_jobs > 0
and policy_allowed
and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms)
):
return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state
if (
decode_waiting > 0
and policy_allowed
and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
):
return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
return "finalize", "finalize_aging", policy_snapshot, worker_state
if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
return "prepare", "prepare_aging", policy_snapshot, worker_state
if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
if decode_waiting > 0 and policy_allowed:
return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state
if finalize_waiting > 0:
return "finalize", "finalize_fallback", policy_snapshot, worker_state
if prepare_waiting > 0:
return "prepare", "prepare_fallback", policy_snapshot, worker_state
return "idle", "no_pending_work", policy_snapshot, worker_state

View File

@ -0,0 +1,381 @@
from __future__ import annotations
import asyncio
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, Dict, Optional, Sequence
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
@dataclass
class DefaultReferenceState:
ref_audio_path: str | None = None
updated_at: float = 0.0
class ReferenceRegistry:
def __init__(self) -> None:
self._lock = threading.Lock()
self._state = DefaultReferenceState()
def set_default(self, ref_audio_path: str) -> DefaultReferenceState:
with self._lock:
self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time())
return self._state
def clear(self) -> DefaultReferenceState:
with self._lock:
self._state = DefaultReferenceState()
return self._state
def get_default(self) -> DefaultReferenceState:
with self._lock:
return DefaultReferenceState(
ref_audio_path=self._state.ref_audio_path,
updated_at=self._state.updated_at,
)
@dataclass
class ModelRegistryState:
t2s_weights_path: str
vits_weights_path: str
generation: int = 0
t2s_generation: int = 0
vits_generation: int = 0
updated_at: float = field(default_factory=time.time)
class ModelRegistry:
def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None:
self._lock = threading.Lock()
self._state = ModelRegistryState(
t2s_weights_path=str(t2s_weights_path),
vits_weights_path=str(vits_weights_path),
)
def snapshot(self) -> ModelRegistryState:
with self._lock:
return ModelRegistryState(
t2s_weights_path=self._state.t2s_weights_path,
vits_weights_path=self._state.vits_weights_path,
generation=self._state.generation,
t2s_generation=self._state.t2s_generation,
vits_generation=self._state.vits_generation,
updated_at=self._state.updated_at,
)
def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState:
with self._lock:
self._state.t2s_weights_path = str(weights_path)
self._state.generation += 1
self._state.t2s_generation += 1
self._state.updated_at = time.time()
return ModelRegistryState(
t2s_weights_path=self._state.t2s_weights_path,
vits_weights_path=self._state.vits_weights_path,
generation=self._state.generation,
t2s_generation=self._state.t2s_generation,
vits_generation=self._state.vits_generation,
updated_at=self._state.updated_at,
)
def mark_vits_reload(self, weights_path: str) -> ModelRegistryState:
with self._lock:
self._state.vits_weights_path = str(weights_path)
self._state.generation += 1
self._state.vits_generation += 1
self._state.updated_at = time.time()
return ModelRegistryState(
t2s_weights_path=self._state.t2s_weights_path,
vits_weights_path=self._state.vits_weights_path,
generation=self._state.generation,
t2s_generation=self._state.t2s_generation,
vits_generation=self._state.vits_generation,
updated_at=self._state.updated_at,
)
class EngineStatus:
NEW = "NEW"
QUEUED = "QUEUED"
VALIDATED = "VALIDATED"
CPU_PREPARING = "CPU_PREPARING"
GPU_PREPARING = "GPU_PREPARING"
READY_FOR_PREFILL = "READY_FOR_PREFILL"
ACTIVE_DECODE = "ACTIVE_DECODE"
READY_FOR_FINALIZE = "READY_FOR_FINALIZE"
FINALIZING = "FINALIZING"
STREAMING = "STREAMING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
@dataclass
class EngineRequestState:
request_id: str
api_mode: str
backend: str
media_type: str
response_streaming: bool
submit_ts: float
deadline_ts: float | None = None
status: str = EngineStatus.NEW
updated_ts: float = 0.0
error: str | None = None
finish_reason: str | None = None
meta: Dict[str, Any] = field(default_factory=dict)
profile: Dict[str, Any] = field(default_factory=dict)
lifecycle_timestamps: Dict[str, float] = field(default_factory=dict)
def to_summary(self) -> Dict[str, Any]:
return {
"request_id": self.request_id,
"api_mode": self.api_mode,
"backend": self.backend,
"media_type": self.media_type,
"response_streaming": self.response_streaming,
"status": self.status,
"submit_ts": self.submit_ts,
"updated_ts": self.updated_ts,
"deadline_ts": self.deadline_ts,
"error": self.error,
"finish_reason": self.finish_reason,
"meta": dict(self.meta),
"profile": dict(self.profile),
"lifecycle_timestamps": dict(self.lifecycle_timestamps),
}
class EngineRequestRegistry:
def __init__(self, recent_limit: int) -> None:
self.lock = threading.Lock()
self.active_requests: Dict[str, EngineRequestState] = {}
self.recent_requests: Deque[EngineRequestState] = deque()
self.recent_limit = max(1, int(recent_limit))
def register(
self,
*,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
now = time.perf_counter()
state = EngineRequestState(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=bool(response_streaming),
submit_ts=now,
deadline_ts=deadline_ts,
updated_ts=now,
meta=dict(meta or {}),
lifecycle_timestamps={EngineStatus.NEW: now},
)
with self.lock:
self.active_requests[request_id] = state
return state
def _move_to_recent_locked(self, state: EngineRequestState) -> None:
self.recent_requests.appendleft(state)
while len(self.recent_requests) > self.recent_limit:
self.recent_requests.pop()
@staticmethod
def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None:
if not extra:
return
payload = dict(extra)
backend = payload.pop("backend", None)
if backend is not None:
state.backend = str(backend)
finish_reason = payload.pop("finish_reason", None)
if finish_reason is not None:
state.finish_reason = str(finish_reason)
error = payload.pop("error", None)
if error is not None:
state.error = str(error)
state.profile.update(payload)
def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
now = time.perf_counter()
with self.lock:
state = self.active_requests.get(request_id)
if state is None:
return
state.status = str(status)
state.updated_ts = now
state.lifecycle_timestamps[str(status)] = now
self._apply_state_extra(state, extra)
def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
if not extra:
return
now = time.perf_counter()
with self.lock:
state = self.active_requests.get(request_id)
if state is None:
for recent_state in self.recent_requests:
if recent_state.request_id == request_id:
state = recent_state
break
if state is None:
return
state.updated_ts = now
self._apply_state_extra(state, extra)
def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
now = time.perf_counter()
with self.lock:
state = self.active_requests.pop(request_id, None)
if state is None:
return
state.status = EngineStatus.COMPLETED
state.updated_ts = now
state.lifecycle_timestamps[EngineStatus.COMPLETED] = now
self._apply_state_extra(state, extra)
self._move_to_recent_locked(state)
def fail(self, request_id: str, error: str) -> None:
now = time.perf_counter()
with self.lock:
state = self.active_requests.pop(request_id, None)
if state is None:
return
state.status = EngineStatus.FAILED
state.updated_ts = now
state.error = str(error)
state.lifecycle_timestamps[EngineStatus.FAILED] = now
self._move_to_recent_locked(state)
def snapshot(self) -> Dict[str, Any]:
with self.lock:
active = [state.to_summary() for state in self.active_requests.values()]
recent = [state.to_summary() for state in list(self.recent_requests)]
recent_limit = self.recent_limit
active.sort(key=lambda item: item["submit_ts"])
return {
"active_count": len(active),
"recent_count": len(recent),
"recent_limit": recent_limit,
"active_requests": active,
"recent_requests": recent,
}
def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]:
requested = set(request_ids)
results: list[Dict[str, Any]] = []
with self.lock:
for state in self.active_requests.values():
if state.request_id in requested:
results.append(state.to_summary())
existing_ids = {item["request_id"] for item in results}
for state in self.recent_requests:
if state.request_id in requested and state.request_id not in existing_ids:
results.append(state.to_summary())
results.sort(key=lambda item: item["request_id"])
return results
def has_active(self, request_id: str) -> bool:
with self.lock:
return request_id in self.active_requests
@dataclass
class SchedulerPendingJob:
request_id: str
state: T2SRequestState
done_event: threading.Event
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
enqueue_time: float
speed_factor: float
sample_steps: int
media_type: str
admission_wait_ms: float = 0.0
engine_policy_wait_ms: float = 0.0
engine_dispatch_wait_ms: float = 0.0
prepare_wall_ms: float = 0.0
prepare_profile_total_ms: float = 0.0
first_schedule_time: float | None = None
prefill_ms: float = 0.0
merge_ms: float = 0.0
decode_ms: float = 0.0
finalize_wait_ms: float = 0.0
synth_ms: float = 0.0
pack_ms: float = 0.0
decode_steps: int = 0
result_ready_time: float | None = None
result: dict | None = None
sample_rate: int | None = None
audio_data: np.ndarray | None = None
error: str | None = None
engine_request_id: str | None = None
class SchedulerJobRegistry:
def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None:
self._lock = lock
self._job_map: Dict[str, SchedulerPendingJob] = {}
self._total_submitted = 0
self._total_finished = 0
def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None:
with self._lock:
if keep_job:
self._job_map[job.request_id] = job
self._total_submitted += 1
def get(self, request_id: str) -> SchedulerPendingJob | None:
with self._lock:
return self._job_map.get(request_id)
def pop(self, request_id: str) -> SchedulerPendingJob | None:
with self._lock:
return self._job_map.pop(request_id, None)
def remove(self, request_id: str) -> None:
with self._lock:
self._job_map.pop(request_id, None)
def mark_finished(self) -> None:
with self._lock:
self._total_finished += 1
def mark_finished_and_remove(self, request_id: str) -> None:
with self._lock:
self._job_map.pop(request_id, None)
self._total_finished += 1
def is_empty(self) -> bool:
with self._lock:
return not self._job_map
def submitted_count(self) -> int:
with self._lock:
return int(self._total_submitted)
def finished_count(self) -> int:
with self._lock:
return int(self._total_finished)
def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]:
with self._lock:
request_ids = list(self._job_map.keys())
return {
"job_count": int(len(request_ids)),
"request_ids": request_ids[: max(0, int(max_request_ids))],
"total_submitted": int(self._total_submitted),
"total_finished": int(self._total_finished),
}

View File

@ -0,0 +1,334 @@
from __future__ import annotations
import asyncio
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Deque, Dict, List, Optional, Sequence
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob
class EngineTaskQueueOwner:
def __init__(self, completion_key: str = "total_completed") -> None:
self.condition = threading.Condition()
self.queue: Deque[Any] = deque()
self.total_submitted = 0
self.total_completed = 0
self.peak_waiting = 0
self.completion_key = str(completion_key)
def enqueue(self, item: Any) -> None:
with self.condition:
self.queue.append(item)
self.total_submitted += 1
self.peak_waiting = max(self.peak_waiting, len(self.queue))
self.condition.notify_all()
def enqueue_many(self, items: Sequence[Any]) -> None:
if not items:
return
with self.condition:
for item in items:
self.queue.append(item)
self.total_submitted += len(items)
self.peak_waiting = max(self.peak_waiting, len(self.queue))
self.condition.notify_all()
def pop_left(self) -> Any | None:
with self.condition:
if not self.queue:
return None
return self.queue.popleft()
def mark_completed(self, count: int = 1, *, notify: bool = False) -> None:
if count <= 0:
return
with self.condition:
self.total_completed += int(count)
if notify:
self.condition.notify_all()
def has_items(self) -> bool:
with self.condition:
return bool(self.queue)
def waiting_count(self) -> int:
with self.condition:
return int(len(self.queue))
def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
with self.condition:
waiting_items = list(self.queue)[: max(0, int(max_request_ids))]
snapshot = {
"waiting_count": int(len(self.queue)),
"waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items],
"peak_waiting": int(self.peak_waiting),
"total_submitted": int(self.total_submitted),
self.completion_key: int(self.total_completed),
}
if extra:
snapshot.update(dict(extra))
return snapshot
def peek_oldest_age_ms(self, timestamp_attr: str) -> float:
with self.condition:
if not self.queue:
return 0.0
enqueue_time = float(getattr(self.queue[0], timestamp_attr))
return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
def is_drained(self) -> bool:
with self.condition:
return not self.queue and self.total_submitted == self.total_completed
def take_finalize_batch(
self,
*,
finalize_mode: str,
batch_max_items: int,
batch_wait_s: float,
use_vocoder: bool,
) -> List[SchedulerFinalizeTask]:
with self.condition:
if not self.queue:
return []
selected_tasks = [self.queue.popleft()]
if finalize_mode == "sync" or use_vocoder:
return selected_tasks
if batch_max_items <= 1:
return selected_tasks
first_task = selected_tasks[0]
oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time)
if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s:
self.queue.appendleft(first_task)
return []
while len(selected_tasks) < batch_max_items:
if not self.queue:
break
matched_index = None
for index, task in enumerate(self.queue):
if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
matched_index = index
break
if matched_index is None:
break
selected_tasks.append(self.queue[matched_index])
del self.queue[matched_index]
return selected_tasks
@dataclass
class EngineDecodeRuntimeState:
pending_jobs: int = 0
pending_request_ids: List[str] = field(default_factory=list)
active_request_count: int = 0
active_request_ids: List[str] = field(default_factory=list)
prefill_done: bool = False
decode_step_index_max: int = 0
total_cycles: int = 0
prefill_cycles: int = 0
step_cycles: int = 0
has_work: bool = False
last_event: str = "init"
updated_at: float = 0.0
class EngineDecodeRuntimeOwner:
def __init__(
self,
*,
get_decode_runtime_counters: Callable[[], Dict[str, int]],
get_micro_batch_wait_s: Callable[[], float],
) -> None:
self.get_decode_runtime_counters = get_decode_runtime_counters
self.get_micro_batch_wait_s = get_micro_batch_wait_s
self.condition = threading.Condition()
self.pending_jobs: Deque[SchedulerPendingJob] = deque()
self.active_batch: T2SActiveBatch | None = None
self.state_lock = threading.Lock()
self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter())
@staticmethod
def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
if active_batch is None:
return {}
decode_step_index_max = 0
if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0:
decode_step_index_max = int(active_batch.step_indices.max().item())
return {
"request_count": int(len(active_batch.request_ids)),
"request_ids": list(active_batch.request_ids),
"prefill_done": bool(active_batch.prefill_done),
"decode_step_index_max": int(decode_step_index_max),
}
def snapshot_pending_queue_state(self) -> Dict[str, Any]:
with self.condition:
return {
"pending_jobs": int(len(self.pending_jobs)),
"pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]],
}
def enqueue_pending_job(self, job: SchedulerPendingJob) -> None:
with self.condition:
self.pending_jobs.append(job)
self.condition.notify_all()
self.refresh_state("engine_decode_pending_enqueue")
def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs:
return []
if wait_for_batch:
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s():
return []
pending_jobs = list(self.pending_jobs)
self.pending_jobs.clear()
self.refresh_state("engine_decode_pending_dequeue")
return pending_jobs
def pending_age_ms(self) -> float:
with self.condition:
if not self.pending_jobs:
return 0.0
enqueue_time = float(self.pending_jobs[0].enqueue_time)
return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
def has_pending_jobs(self) -> bool:
with self.condition:
return bool(self.pending_jobs)
def get_active_batch(self) -> T2SActiveBatch | None:
return self.active_batch
def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None:
self.active_batch = active_batch
def active_batch_summary(self) -> Dict[str, Any]:
return self.summarize_active_batch(self.active_batch)
def refresh_state(self, last_event: str) -> None:
pending_state = self.snapshot_pending_queue_state()
active_batch_summary = self.active_batch_summary()
worker_decode_counters = self.get_decode_runtime_counters()
with self.state_lock:
self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
self.state.active_request_count = int(active_batch_summary.get("request_count", 0))
self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32]
self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False))
self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0))
self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0))
self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0))
self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0))
self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0))
self.state.last_event = str(last_event)
self.state.updated_at = float(time.perf_counter())
def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None:
if not snapshot:
return
pending_state = self.snapshot_pending_queue_state()
with self.state_lock:
self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
self.state.active_request_count = int(snapshot.get("active_request_count", 0))
self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32]
self.state.prefill_done = bool(snapshot.get("prefill_done", False))
self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0))
self.state.total_cycles = int(snapshot.get("total_cycles", 0))
self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0))
self.state.step_cycles = int(snapshot.get("step_cycles", 0))
self.state.has_work = bool(
pending_state.get("pending_jobs", 0)
or snapshot.get("active_request_count", 0)
or snapshot.get("has_work", False)
)
self.state.last_event = str(snapshot.get("last_event", "unknown"))
self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter()))
def snapshot_state(self) -> Dict[str, Any]:
pending_state = self.snapshot_pending_queue_state()
active_batch_summary = self.active_batch_summary()
worker_decode_counters = self.get_decode_runtime_counters()
with self.state_lock:
return {
"pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)),
"pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)),
"active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)),
"active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)),
"prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)),
"decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)),
"total_cycles": int(worker_decode_counters.get("total_cycles", 0)),
"prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)),
"step_cycles": int(worker_decode_counters.get("step_cycles", 0)),
"has_work": bool(
pending_state.get("pending_jobs", 0)
or active_batch_summary.get("request_count", self.state.active_request_count)
or self.state.has_work
),
"last_event": str(self.state.last_event),
"updated_at": float(self.state.updated_at),
}
@dataclass
class SchedulerFinalizeTask:
request_id: str
item: T2SFinishedItem
enqueued_time: float
@dataclass
class EngineDispatchTask:
request_id: str
state: T2SRequestState
speed_factor: float
sample_steps: int
media_type: str
prepare_wall_ms: float
prepare_profile_total_ms: float
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
engine_request_id: str | None
timeout_sec: float | None
enqueue_time: float
worker_job: SchedulerPendingJob | None = None
engine_policy_wait_ms: float = 0.0
engine_dispatch_wait_ms: float = 0.0
engine_policy_snapshot: Dict[str, Any] | None = None
error: str | None = None
@dataclass
class EngineGpuPrepareTask:
request_id: str
cpu_stage: PreparedCpuStage
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
engine_request_id: str | None
enqueue_time: float
queue_wait_ms: float = 0.0
error: str | None = None
@dataclass
class EngineFinalizeQueueState:
waiting_count: int
waiting_request_ids: List[str]
peak_waiting: int
total_submitted: int
total_completed: int
@dataclass
class RuntimeStateCallbacks:
update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None
complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None
fail: Callable[[str, str], None] | None = None
decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,198 @@
from __future__ import annotations
import threading
import time
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob
class WorkerCompletionBridge:
def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None:
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
@staticmethod
def _resolve_done_future(job: SchedulerPendingJob) -> None:
future = job.done_future
if future is None or future.done():
return
future.set_result(job)
def notify_done_future(self, job: SchedulerPendingJob) -> None:
if job.done_loop is None or job.done_future is None:
return
try:
job.done_loop.call_soon_threadsafe(self._resolve_done_future, job)
except RuntimeError:
pass
def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
if request_id is None or self.runtime_callbacks.complete is None:
return
self.runtime_callbacks.complete(request_id, extra)
def runtime_fail(self, request_id: str | None, error: str) -> None:
if request_id is None or self.runtime_callbacks.fail is None:
return
self.runtime_callbacks.fail(request_id, error)
@staticmethod
def build_completed_job_result(
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
finished_at: float | None = None,
) -> Dict[str, Any]:
finished_at = float(time.perf_counter() if finished_at is None else finished_at)
queue_wait_ms = 0.0
if job.first_schedule_time is not None:
queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
worker_residual_ms = max(
0.0,
worker_total_ms
- queue_wait_ms
- job.prefill_ms
- job.merge_ms
- job.decode_ms
- job.finalize_wait_ms
- job.synth_ms,
)
worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
job.sample_rate = int(sample_rate)
job.audio_data = audio_data
job.result_ready_time = finished_at
prepare_profile = dict(job.state.prepare_profile)
result = {
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"decode_admission_wait_ms": float(job.admission_wait_ms),
"engine_policy_wait_ms": float(job.engine_policy_wait_ms),
"engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms),
"prepare_ms": job.prepare_wall_ms,
"prepare_wall_ms": job.prepare_wall_ms,
"prepare_profile_total_ms": job.prepare_profile_total_ms,
"prepare_profile": prepare_profile,
"queue_wait_ms": queue_wait_ms,
"prefill_ms": job.prefill_ms,
"merge_ms": job.merge_ms,
"decode_ms": job.decode_ms,
"finalize_wait_ms": job.finalize_wait_ms,
"synth_ms": job.synth_ms,
"worker_residual_ms": worker_residual_ms,
"worker_other_ms": worker_other_ms,
"worker_total_ms": worker_total_ms,
"decode_steps": int(job.decode_steps),
"sample_rate": int(sample_rate),
"media_type": job.media_type,
}
job.result = result
return result
@staticmethod
def build_runtime_complete_payload(
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
) -> Dict[str, Any]:
return {
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"sample_rate": int(sample_rate),
"worker_profile": dict(job.result or {}),
}
def complete_job(
self,
job: SchedulerPendingJob,
*,
runtime_request_id: str | None,
runtime_extra: Optional[Dict[str, Any]] = None,
remove_job: Callable[[], None] | None = None,
on_job_finished: Callable[[], None] | None = None,
notify_waiters: Callable[[], None] | None = None,
) -> None:
job.done_event.set()
self.notify_done_future(job)
if remove_job is not None:
remove_job()
if on_job_finished is not None:
on_job_finished()
if notify_waiters is not None:
notify_waiters()
self.runtime_complete(runtime_request_id, runtime_extra)
def fail_job(
self,
job: SchedulerPendingJob,
*,
error: str,
remove_job: Callable[[], None] | None = None,
on_job_finished: Callable[[], None] | None = None,
notify_waiters: Callable[[], None] | None = None,
) -> None:
job.error = str(error)
job.done_event.set()
self.notify_done_future(job)
if remove_job is not None:
remove_job()
if on_job_finished is not None:
on_job_finished()
if notify_waiters is not None:
notify_waiters()
self.runtime_fail(job.engine_request_id, str(error))
def complete_finalize_task(
self,
*,
condition: threading.Condition,
job_registry: SchedulerJobRegistry,
job: SchedulerPendingJob,
item: T2SFinishedItem,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
runtime_extra: Optional[Dict[str, Any]] = None
with condition:
if job_registry.get(item.request_id) is not job:
return
self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate)
self.complete_job(
job,
runtime_request_id=job.engine_request_id,
runtime_extra=runtime_extra,
on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id),
notify_waiters=condition.notify_all,
)
def fail_jobs(
self,
*,
condition: threading.Condition,
job_registry: SchedulerJobRegistry,
request_ids: List[str],
error: str,
) -> None:
if not request_ids:
return
with condition:
for request_id in request_ids:
job = job_registry.get(request_id)
if job is None:
continue
self.fail_job(
job,
error=error,
on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid),
)
condition.notify_all()

View File

@ -0,0 +1,430 @@
from __future__ import annotations
import threading
import time
from typing import Any, Callable, Dict, List, Optional
import torch
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
T2SActiveBatch,
T2SFinishedItem,
decode_one_step,
merge_active_batches,
run_prefill_active_batch,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob
class WorkerDecodeExecutor:
def __init__(self, tts: TTS, max_steps: int) -> None:
self.tts = tts
self.max_steps = int(max_steps)
def _sync_device(self) -> None:
try:
device_str = str(self.tts.configs.device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(self.tts.configs.device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def execute_prefill_merge(
self,
*,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None],
add_prefill_time: Callable[[List[str], float], None] | None,
add_merge_time: Callable[[List[str], float], None] | None,
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
finalize_error: Callable[[List[str], str], None] | None,
) -> Dict[str, Any]:
if not pending_jobs:
return {
"executed": False,
"active_batch": active_batch,
"pending_jobs": [],
"prefill_elapsed_s": 0.0,
"merge_elapsed_s": 0.0,
"finished_items": [],
"error": None,
"error_request_ids": [],
}
admitted_finished: List[T2SFinishedItem] = []
prefill_elapsed_s = 0.0
merge_elapsed_s = 0.0
error: str | None = None
error_request_ids: List[str] = []
try:
self._sync_device()
prefill_start = time.perf_counter()
mark_prefill_started(pending_jobs, prefill_start)
admitted_active_batch, admitted_finished = run_prefill_active_batch(
self.tts.t2s_model.model,
[job.state for job in pending_jobs],
max_steps=self.max_steps,
)
self._sync_device()
prefill_elapsed_s = time.perf_counter() - prefill_start
if add_prefill_time is not None:
add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s)
if enqueue_finished is not None:
enqueue_finished(admitted_finished)
merge_start = time.perf_counter()
active_batch = merge_active_batches(
self.tts.t2s_model.model,
active_batch,
admitted_active_batch,
)
merge_elapsed_s = time.perf_counter() - merge_start
if add_merge_time is not None:
add_merge_time(
[] if active_batch is None else list(active_batch.request_ids),
merge_elapsed_s,
)
except Exception as exc:
error = str(exc)
error_request_ids = [job.request_id for job in pending_jobs]
if finalize_error is not None:
finalize_error(error_request_ids, error)
return {
"executed": True,
"active_batch": active_batch,
"pending_jobs": list(pending_jobs),
"prefill_elapsed_s": float(prefill_elapsed_s),
"merge_elapsed_s": float(merge_elapsed_s),
"finished_items": list(admitted_finished),
"error": error,
"error_request_ids": error_request_ids,
}
def execute_decode_step(
self,
*,
active_batch: Optional[T2SActiveBatch],
add_decode_time: Callable[[List[str], float], None] | None,
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
finalize_error: Callable[[List[str], str], None] | None,
) -> Dict[str, Any]:
if active_batch is None:
return {
"executed": False,
"active_batch": None,
"request_ids": [],
"decode_elapsed_s": 0.0,
"finished_items": [],
"error": None,
"error_request_ids": [],
}
active_request_ids: List[str] = []
step_finished: List[T2SFinishedItem] = []
decode_elapsed_s = 0.0
error: str | None = None
error_request_ids: List[str] = []
try:
active_request_ids = [state.request_id for state in active_batch.states]
self._sync_device()
decode_start = time.perf_counter()
active_batch, step_finished = decode_one_step(
self.tts.t2s_model.model,
active_batch,
max_steps=self.max_steps,
)
self._sync_device()
decode_elapsed_s = time.perf_counter() - decode_start
if add_decode_time is not None:
add_decode_time(active_request_ids, decode_elapsed_s)
if enqueue_finished is not None:
enqueue_finished(step_finished)
except Exception as exc:
error = str(exc)
error_request_ids = list(active_request_ids)
if finalize_error is not None:
finalize_error(error_request_ids, error)
active_batch = None
return {
"executed": True,
"active_batch": active_batch,
"request_ids": active_request_ids,
"decode_elapsed_s": float(decode_elapsed_s),
"finished_items": list(step_finished),
"error": error,
"error_request_ids": error_request_ids,
}
def execute_decode_cycle(
self,
*,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None],
add_prefill_time: Callable[[List[str], float], None] | None,
add_merge_time: Callable[[List[str], float], None] | None,
add_decode_time: Callable[[List[str], float], None] | None,
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
finalize_error: Callable[[List[str], str], None] | None,
) -> Dict[str, Any]:
result = {
"executed": False,
"prefill_merge_executed": False,
"decode_step_executed": False,
"active_batch": active_batch,
"prefill_phase": {},
"decode_phase": {},
}
prefill_phase = self.execute_prefill_merge(
pending_jobs=list(pending_jobs),
active_batch=result["active_batch"],
mark_prefill_started=mark_prefill_started,
add_prefill_time=add_prefill_time,
add_merge_time=add_merge_time,
enqueue_finished=enqueue_finished,
finalize_error=finalize_error,
)
prefill_executed = bool(prefill_phase.get("executed", False))
result["prefill_phase"] = prefill_phase
result["active_batch"] = prefill_phase.get("active_batch")
if prefill_executed:
result["executed"] = True
result["prefill_merge_executed"] = True
decode_phase = self.execute_decode_step(
active_batch=result["active_batch"],
add_decode_time=add_decode_time,
enqueue_finished=enqueue_finished,
finalize_error=finalize_error,
)
decode_executed = bool(decode_phase.get("executed", False))
result["decode_phase"] = decode_phase
result["active_batch"] = decode_phase.get("active_batch")
if decode_executed:
result["executed"] = True
result["decode_step_executed"] = True
return result
class WorkerDecodeLegacyShell:
def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None:
self.condition = condition
self.micro_batch_wait_s = float(micro_batch_wait_s)
self.pending_jobs: List[SchedulerPendingJob] = []
self.active_batch: T2SActiveBatch | None = None
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None:
if active_batch is None:
return None
return {
"request_count": int(len(active_batch.request_ids)),
"request_ids": list(active_batch.request_ids),
"prefill_done": bool(active_batch.prefill_done),
"decode_step_index_max": (
int(active_batch.step_indices.max().item())
if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0
else 0
),
}
def current_backlog_locked(self) -> int:
running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids)
return int(len(self.pending_jobs) + running_requests)
def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None:
self.pending_jobs.append(job)
def snapshot_locked(self) -> Dict[str, Any]:
active_batch_summary = self._summarize_active_batch(self.active_batch)
executor_local_pending_jobs = int(len(self.pending_jobs))
executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids))
executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None)
return {
"executor_local_pending_jobs": executor_local_pending_jobs,
"executor_local_running_requests": executor_local_running_requests,
"executor_local_has_work": executor_local_has_work,
"executor_local_active_batch": active_batch_summary,
}
def is_idle_locked(self) -> bool:
return self.active_batch is None and not self.pending_jobs
def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs and self.active_batch is None:
self.condition.wait(timeout=self.micro_batch_wait_s)
elif wait_for_batch and self.pending_jobs:
self.condition.wait(timeout=self.micro_batch_wait_s)
if not self.pending_jobs:
return []
pending = list(self.pending_jobs)
self.pending_jobs.clear()
return pending
def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs:
return []
if wait_for_batch:
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s:
return []
pending = list(self.pending_jobs)
self.pending_jobs.clear()
return pending
def has_decode_runtime_work(self) -> bool:
with self.condition:
return bool(self.pending_jobs or self.active_batch is not None)
def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]:
active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids)
decode_step_index_max = 0
prefill_done = False
if self.active_batch is not None:
prefill_done = bool(self.active_batch.prefill_done)
if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0:
decode_step_index_max = int(self.active_batch.step_indices.max().item())
return {
"pending_jobs": int(len(self.pending_jobs)),
"active_request_count": int(len(active_request_ids)),
"active_request_ids": active_request_ids[:32],
"prefill_done": bool(prefill_done),
"decode_step_index_max": int(decode_step_index_max),
"total_cycles": int(total_cycles),
"prefill_cycles": int(prefill_cycles),
"step_cycles": int(step_cycles),
"has_work": bool(self.pending_jobs or self.active_batch is not None),
"last_event": str(last_event),
"updated_at": float(time.perf_counter()),
}
def run_prefill_merge_once_nonblocking(
self,
*,
external_pending_jobs: Optional[List[SchedulerPendingJob]],
external_active_batch: Optional[T2SActiveBatch],
execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]],
) -> Dict[str, Any]:
pending_jobs = (
list(external_pending_jobs)
if external_pending_jobs is not None
else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None)
)
active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch
result = execute_prefill_merge(pending_jobs, active_batch)
if external_pending_jobs is None:
with self.condition:
self.active_batch = result.get("active_batch")
self.condition.notify_all()
return result
def run_decode_step_once_nonblocking(
self,
*,
external_active_batch: Optional[T2SActiveBatch],
execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]],
) -> Dict[str, Any]:
active_batch = self.active_batch if external_active_batch is None else external_active_batch
result = execute_decode_step(active_batch)
if external_active_batch is None:
with self.condition:
self.active_batch = result.get("active_batch")
self.condition.notify_all()
return result
def run_decode_cycle_nonblocking(
self,
*,
external_pending_jobs: Optional[List[SchedulerPendingJob]],
external_active_batch: Optional[T2SActiveBatch],
execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]],
on_cycle_executed: Callable[[Dict[str, Any]], None] | None,
) -> Dict[str, Any]:
pending_jobs = (
list(external_pending_jobs)
if external_pending_jobs is not None
else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None)
)
active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch
result = execute_decode_cycle(pending_jobs, active_batch)
if external_pending_jobs is None:
with self.condition:
self.active_batch = result.get("active_batch")
self.condition.notify_all()
if result.get("executed") and on_cycle_executed is not None:
on_cycle_executed(result)
return result
def run_loop(
self,
*,
run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]],
) -> None:
while True:
executed = run_decode_cycle_nonblocking()
if executed.get("executed"):
continue
wait_for_batch = self.active_batch is None
pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch)
if pending_jobs:
with self.condition:
self.pending_jobs = pending_jobs + self.pending_jobs
self.condition.notify_all()
continue
time.sleep(self.micro_batch_wait_s)
class WorkerDecodeRuntimeTracker:
def __init__(
self,
runtime_callbacks: RuntimeStateCallbacks | None = None,
) -> None:
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
self.total_cycles = 0
self.prefill_cycles = 0
self.step_cycles = 0
def get_counters(self) -> Dict[str, int]:
return {
"total_cycles": int(self.total_cycles),
"prefill_cycles": int(self.prefill_cycles),
"step_cycles": int(self.step_cycles),
}
def record_cycle(self, result: Dict[str, Any]) -> None:
if not bool(result.get("executed")):
return
self.total_cycles += 1
if bool(result.get("prefill_merge_executed")):
self.prefill_cycles += 1
if bool(result.get("decode_step_executed")):
self.step_cycles += 1
def build_runtime_summary_locked(
self,
*,
legacy_shell: WorkerDecodeLegacyShell,
last_event: str,
) -> Dict[str, Any]:
return legacy_shell.build_runtime_summary_locked(
total_cycles=int(self.total_cycles),
prefill_cycles=int(self.prefill_cycles),
step_cycles=int(self.step_cycles),
last_event=str(last_event),
)
def notify_runtime_update_locked(
self,
*,
legacy_shell: WorkerDecodeLegacyShell,
last_event: str,
) -> None:
if self.runtime_callbacks.decode_runtime_update is None:
return
snapshot = self.build_runtime_summary_locked(
legacy_shell=legacy_shell,
last_event=last_event,
)
self.runtime_callbacks.decode_runtime_update(snapshot)

View File

@ -0,0 +1,164 @@
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class WorkerExecutionMixin:
def execute_prefill_merge(
self,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
return self.decode_executor.execute_prefill_merge(
pending_jobs=pending_jobs,
active_batch=active_batch,
mark_prefill_started=self._mark_prefill_started,
add_prefill_time=None if external_bookkeeping else self._add_prefill_time,
add_merge_time=None if external_bookkeeping else self._add_merge_time,
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
finalize_error=None if external_bookkeeping else self._finalize_error,
)
def execute_decode_step(
self,
active_batch: Optional[T2SActiveBatch],
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
return self.decode_executor.execute_decode_step(
active_batch=active_batch,
add_decode_time=None if external_bookkeeping else self._add_decode_time,
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
finalize_error=None if external_bookkeeping else self._finalize_error,
)
def execute_decode_cycle(
self,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_executor.execute_decode_cycle(
pending_jobs=pending_jobs,
active_batch=active_batch,
mark_prefill_started=self._mark_prefill_started,
add_prefill_time=None if external_bookkeeping else self._add_prefill_time,
add_merge_time=None if external_bookkeeping else self._add_merge_time,
add_decode_time=None if external_bookkeeping else self._add_decode_time,
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
finalize_error=None if external_bookkeeping else self._finalize_error,
)
self._record_decode_runtime_cycle(result)
return result
def run_prefill_merge_once_nonblocking(
self,
external_pending_jobs: Optional[List[SchedulerPendingJob]] = None,
external_active_batch: Optional[T2SActiveBatch] = None,
emit_runtime_state: bool = True,
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking(
external_pending_jobs=external_pending_jobs,
external_active_batch=external_active_batch,
execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge(
pending_jobs=batch_jobs,
active_batch=batch_state,
external_bookkeeping=external_bookkeeping,
),
)
if emit_runtime_state:
self._notify_decode_runtime_state("prefill_merge")
return result
def run_decode_step_once_nonblocking(
self,
external_active_batch: Optional[T2SActiveBatch] = None,
emit_runtime_state: bool = True,
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_legacy_shell.run_decode_step_once_nonblocking(
external_active_batch=external_active_batch,
execute_decode_step=lambda batch_state: self.execute_decode_step(
active_batch=batch_state,
external_bookkeeping=external_bookkeeping,
),
)
if emit_runtime_state:
self._notify_decode_runtime_state("decode_step")
return result
def run_decode_cycle_nonblocking(
self,
external_pending_jobs: Optional[List[SchedulerPendingJob]] = None,
external_active_batch: Optional[T2SActiveBatch] = None,
emit_runtime_state: bool = True,
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_legacy_shell.run_decode_cycle_nonblocking(
external_pending_jobs=external_pending_jobs,
external_active_batch=external_active_batch,
execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle(
pending_jobs=batch_jobs,
active_batch=batch_state,
external_bookkeeping=external_bookkeeping,
),
on_cycle_executed=None,
)
if result.get("executed") and emit_runtime_state:
self._notify_decode_runtime_state("decode_cycle")
return result
def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
with self.condition:
for task in tasks:
job = self.job_registry.get(task.request_id)
if job is None:
continue
jobs_and_items.append((job, task.item))
if not jobs_and_items:
return
now = time.perf_counter()
for task in tasks:
self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0))
for job, item in jobs_and_items:
self._runtime_update(
job.engine_request_id,
EngineStatus.FINALIZING,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
},
)
synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items)
with self.condition:
for job, _ in jobs_and_items:
tracked_job = self.job_registry.get(job.request_id)
if tracked_job is not None:
tracked_job.synth_ms += synth_ms
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc:
self._finalize_error([task.request_id for task in tasks], str(exc))
finally:
self.finalize_executor.end_execution(len(tasks))
def _run_finalize_loop(self) -> None:
while True:
tasks = self.finalize_executor.take_task_batch_blocking()
self.execute_finalize_tasks(tasks)
def _run_loop(self) -> None:
self.decode_legacy_shell.run_loop(
run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking()
)

View File

@ -0,0 +1,234 @@
from __future__ import annotations
import os
import threading
import time
from collections import deque
from typing import Any, Callable, Deque, Dict, List
import numpy as np
import torch
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob
class WorkerFinalizeExecutor:
def __init__(
self,
tts: TTS,
on_state_change: Callable[[], None] | None = None,
external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None,
) -> None:
self.tts = tts
self.on_state_change = on_state_change
self.external_submit = external_submit
self.condition = threading.Condition()
self.pending_tasks: Deque[SchedulerFinalizeTask] = deque()
self.pending_peak = 0
self.inflight = 0
self.inflight_peak = 0
self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
def _notify_state_change(self) -> None:
if self.on_state_change is None:
return
try:
self.on_state_change()
except Exception:
pass
def get_worker_count(self) -> int:
return int(self.worker_count)
def get_batch_policy(self) -> Dict[str, Any]:
return {
"finalize_mode": str(self.finalize_mode),
"finalize_batch_max_items": int(self.batch_max_items),
"finalize_batch_wait_s": float(self.batch_wait_s),
}
def get_pending_count(self) -> int:
with self.condition:
return int(len(self.pending_tasks))
def snapshot(self) -> Dict[str, Any]:
with self.condition:
return {
"finalize_pending": int(len(self.pending_tasks)),
"finalize_pending_peak": int(self.pending_peak),
"finalize_inflight": int(self.inflight),
"finalize_inflight_peak": int(self.inflight_peak),
"finalize_workers": int(self.worker_count),
"finalize_mode": str(self.finalize_mode),
"finalize_batch_max_items": int(self.batch_max_items),
"finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0),
}
def is_idle(self) -> bool:
with self.condition:
return self.inflight <= 0 and not self.pending_tasks
def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
if self.external_submit is not None:
self.external_submit(tasks)
self._notify_state_change()
return
with self.condition:
for task in tasks:
self.pending_tasks.append(task)
self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
self.condition.notify_all()
self._notify_state_change()
def begin_execution(self, task_count: int) -> None:
if task_count <= 0:
return
with self.condition:
self.inflight += int(task_count)
self.inflight_peak = max(self.inflight_peak, self.inflight)
self.condition.notify_all()
self._notify_state_change()
def end_execution(self, task_count: int) -> None:
with self.condition:
self.inflight = max(0, self.inflight - int(task_count))
self.condition.notify_all()
self._notify_state_change()
def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
selected_tasks = [self.pending_tasks.popleft()]
if self.finalize_mode == "sync" or self.tts.configs.use_vocoder:
self.inflight += len(selected_tasks)
self.inflight_peak = max(self.inflight_peak, self.inflight)
self._notify_state_change()
return selected_tasks
batch_deadline = time.perf_counter() + self.batch_wait_s
while len(selected_tasks) < self.batch_max_items:
if not self.pending_tasks:
remaining = batch_deadline - time.perf_counter()
if remaining <= 0:
break
self.condition.wait(timeout=remaining)
continue
first_task = selected_tasks[0]
matched_index = None
for index, task in enumerate(self.pending_tasks):
if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
matched_index = index
break
if matched_index is not None:
selected_tasks.append(self.pending_tasks[matched_index])
del self.pending_tasks[matched_index]
continue
remaining = batch_deadline - time.perf_counter()
if remaining <= 0:
break
self.condition.wait(timeout=remaining)
self.inflight += len(selected_tasks)
self.inflight_peak = max(self.inflight_peak, self.inflight)
self._notify_state_change()
return selected_tasks
def _sync_device(self) -> None:
try:
device_str = str(self.tts.configs.device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(self.tts.configs.device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
audio_fragment = self.tts.synthesize_audio_request_local(
semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0),
phones=job.state.phones.detach().clone().unsqueeze(0),
prompt_semantic=job.state.prompt_semantic.detach().clone(),
prompt_phones=job.state.prompt_phones.detach().clone(),
refer_spec=(
job.state.refer_spec[0].detach().clone(),
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
),
raw_audio=job.state.raw_audio.detach().clone(),
raw_sr=int(job.state.raw_sr),
speed=float(job.speed_factor),
sample_steps=int(job.sample_steps),
)
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
return self.tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=float(job.speed_factor),
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
def _synthesize_finished_audio_batch(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> List[tuple[int, np.ndarray]]:
semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items]
phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items]
refer_specs = []
speeds = []
sample_steps_list = []
for job, _ in jobs_and_items:
refer_specs.append(
(
job.state.refer_spec[0].detach().clone(),
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
)
)
speeds.append(float(job.speed_factor))
sample_steps_list.append(int(job.sample_steps))
audio_fragments = self.tts.synthesize_audio_requests_local_batched(
semantic_tokens_list=semantic_tokens_list,
phones_list=phones_list,
refer_specs=refer_specs,
speeds=speeds,
sample_steps_list=sample_steps_list,
)
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
results: List[tuple[int, np.ndarray]] = []
for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments):
results.append(
self.tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=float(job.speed_factor),
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
)
return results
def synthesize_finalize_jobs(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> tuple[float, List[tuple[int, np.ndarray]]]:
if not jobs_and_items:
return 0.0, []
self._sync_device()
synth_start = time.perf_counter()
if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder:
job, item = jobs_and_items[0]
batch_results = [self._synthesize_finished_audio(job, item)]
else:
batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
self._sync_device()
synth_ms = (time.perf_counter() - synth_start) * 1000.0
return float(synth_ms), batch_results

View File

@ -0,0 +1,71 @@
from __future__ import annotations
import asyncio
import time
from typing import Callable, Dict, List
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
class WorkerPrepareExecutor:
def __init__(
self,
tts: TTS,
on_state_change: Callable[[], None] | None = None,
) -> None:
self.coordinator = PrepareCoordinator(tts)
self.on_state_change = on_state_change
def _notify_state_change(self) -> None:
if self.on_state_change is None:
return
try:
self.on_state_change()
except Exception:
pass
def snapshot(self) -> Dict[str, int]:
return dict(self.coordinator.snapshot())
def get_max_inflight(self) -> int:
return int(self.coordinator.snapshot().get("max_inflight", 0))
def is_idle(self) -> bool:
return int(self.coordinator.snapshot().get("inflight", 0)) <= 0
async def prepare_state_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> tuple[T2SRequestState, float, float]:
try:
return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
finally:
self._notify_state_change()
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
results = await asyncio.gather(
*[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs]
)
return [state for state, _, _ in results]
async def prepare_cpu_stage_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> PreparedCpuStage:
try:
return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
finally:
self._notify_state_change()
async def prepare_gpu_stage_profiled_async(
self,
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
try:
return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage)
finally:
self._notify_state_change()

View File

@ -0,0 +1,170 @@
from __future__ import annotations
import threading
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class WorkerRuntimeBookkeepingMixin:
def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None:
with self.condition:
for job in pending_jobs:
job.first_schedule_time = float(started_at)
self._runtime_update(
job.engine_request_id,
EngineStatus.GPU_PREPARING,
{"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)},
)
def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
if not request_ids:
return
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
job.prefill_ms += delta_ms
def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
if not request_ids:
return
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
job.merge_ms += delta_ms
def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
if not request_ids:
return
activate_request_ids: List[str] = []
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
if job.decode_steps == 0:
activate_request_ids.append(job.engine_request_id)
job.decode_ms += delta_ms
job.decode_steps += 1
for engine_request_id in activate_request_ids:
self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None:
if not request_ids:
return
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
job.finalize_wait_ms += float(delta_ms)
def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None:
if not items:
return
enqueued_at = time.perf_counter()
tasks: List[SchedulerFinalizeTask] = []
with self.condition:
for item in items:
job = self.job_registry.get(item.request_id)
if job is not None:
self._runtime_update(
job.engine_request_id,
EngineStatus.READY_FOR_FINALIZE,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
},
)
tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at))
self.finalize_executor.enqueue_tasks(tasks)
def begin_finalize_execution(self, task_count: int) -> None:
self.finalize_executor.begin_execution(task_count)
def end_finalize_execution(self, task_count: int) -> None:
self.finalize_executor.end_execution(task_count)
def record_external_job_done(self, request_id: str) -> None:
with self.condition:
self.job_registry.mark_finished_and_remove(request_id)
self.condition.notify_all()
def synthesize_finalize_jobs(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> tuple[float, List[tuple[int, np.ndarray]]]:
return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items)
def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None:
self.completion_bridge.complete_finalize_task(
condition=self.condition,
job_registry=self.job_registry,
job=job,
item=item,
sample_rate=sample_rate,
audio_data=audio_data,
)
def _finalize_error(self, request_ids: List[str], error: str) -> None:
self.completion_bridge.fail_jobs(
condition=self.condition,
job_registry=self.job_registry,
request_ids=request_ids,
error=error,
)
@staticmethod
def _resolve_done_future(job: SchedulerPendingJob) -> None:
future = job.done_future
if future is None or future.done():
return
future.set_result(job)
def _notify_done_future(self, job: SchedulerPendingJob) -> None:
self.completion_bridge.notify_done_future(job)
def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
if request_id is None or self.runtime_callbacks.update is None:
return
self.runtime_callbacks.update(request_id, status, extra)
def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
self.completion_bridge.runtime_complete(request_id, extra)
def _runtime_fail(self, request_id: str | None, error: str) -> None:
self.completion_bridge.runtime_fail(request_id, error)
def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]:
return self.decode_runtime_tracker.build_runtime_summary_locked(
legacy_shell=self.decode_legacy_shell,
last_event=str(last_event),
)
def _notify_decode_runtime_state(self, last_event: str) -> None:
with self.condition:
self.decode_runtime_tracker.notify_runtime_update_locked(
legacy_shell=self.decode_legacy_shell,
last_event=str(last_event),
)
def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None:
with self.condition:
self.decode_runtime_tracker.record_cycle(result)
def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch)
def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch)
def has_decode_runtime_work(self) -> bool:
return self.decode_legacy_shell.has_decode_runtime_work()

View File

@ -0,0 +1,256 @@
from __future__ import annotations
import asyncio
import threading
import time
from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob
class WorkerSubmitLifecycleMixin:
def _current_decode_backlog_locked(self) -> int:
return self.decode_legacy_shell.current_backlog_locked()
def get_micro_batch_wait_s(self) -> float:
return float(self.micro_batch_wait_s)
def is_engine_decode_control_enabled(self) -> bool:
return bool(self.engine_decode_control_enabled)
def get_prepare_max_inflight(self) -> int:
return int(self.prepare_executor.get_max_inflight())
def get_capacity_limits(self) -> Dict[str, int]:
return {
"decode_backlog_max": int(self.decode_backlog_max),
"finalize_pending_max": int(self.finalize_pending_max),
}
def get_finalize_batch_policy(self) -> Dict[str, Any]:
return dict(self.finalize_executor.get_batch_policy())
def get_decode_runtime_counters(self) -> Dict[str, int]:
with self.condition:
return self.decode_runtime_tracker.get_counters()
def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]:
decode_backlog = self._current_decode_backlog_locked()
finalize_pending = int(self.finalize_executor.get_pending_count())
prepare_inflight = int(self.prepare_executor.snapshot()["inflight"])
blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max
blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max
return (
not blocked_decode and not blocked_finalize,
{
"decode_backlog": decode_backlog,
"finalize_pending": finalize_pending,
"prepare_inflight": prepare_inflight,
"decode_backlog_max": int(self.decode_backlog_max),
"finalize_pending_max": int(self.finalize_pending_max),
},
)
def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]:
start = time.perf_counter()
deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
while True:
with self.condition:
allowed, snapshot = self._can_accept_submit_locked()
if allowed:
return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot
if deadline is not None and time.perf_counter() >= deadline:
raise TimeoutError(
"scheduler submit admission timeout "
f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})"
)
self.condition.wait(timeout=self.micro_batch_wait_s)
def _admission_snapshot_locked(self) -> Dict[str, int]:
_, snapshot = self._can_accept_submit_locked()
return snapshot
async def submit_async(
self,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None,
done_future: asyncio.Future | None = None,
engine_request_id: str | None = None,
timeout_sec: float | None = None,
skip_capacity_wait: bool = False,
admission_wait_ms_override: float | None = None,
admission_snapshot_override: Dict[str, Any] | None = None,
engine_policy_wait_ms: float = 0.0,
engine_dispatch_wait_ms: float = 0.0,
enqueue_pending: bool = True,
) -> SchedulerPendingJob:
return await asyncio.to_thread(
self.submit,
state,
speed_factor,
sample_steps,
media_type,
prepare_wall_ms,
prepare_profile_total_ms,
done_loop,
done_future,
engine_request_id,
timeout_sec,
skip_capacity_wait,
admission_wait_ms_override,
admission_snapshot_override,
engine_policy_wait_ms,
engine_dispatch_wait_ms,
enqueue_pending,
)
def snapshot(self) -> dict:
with self.condition:
prepare_state = self.prepare_executor.snapshot()
finalize_state = self.finalize_executor.snapshot()
shell_state = self.decode_legacy_shell.snapshot_locked()
decode_runtime_counters = self.decode_runtime_tracker.get_counters()
engine_owned_decode_state = bool(self.engine_decode_control_enabled)
active_batch_summary = shell_state.get("executor_local_active_batch")
executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0))
executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0))
executor_local_has_work = bool(shell_state.get("executor_local_has_work", False))
return {
"pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs,
"running_requests": 0 if engine_owned_decode_state else executor_local_running_requests,
"engine_decode_control_enabled": bool(self.engine_decode_control_enabled),
"legacy_state_owner_mode": not engine_owned_decode_state,
"decode_state_owner": "engine" if engine_owned_decode_state else "worker",
"decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work,
"executor_local_pending_jobs": executor_local_pending_jobs,
"executor_local_running_requests": executor_local_running_requests,
"executor_local_has_work": executor_local_has_work,
"decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)),
"decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)),
"decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)),
"prepare_inflight": prepare_state["inflight"],
"prepare_peak_inflight": prepare_state["peak_inflight"],
"prepare_max_inflight": prepare_state.get("max_inflight", 0),
"prepare_state": dict(prepare_state),
**finalize_state,
"decode_backlog_max": self.decode_backlog_max,
"finalize_pending_max": self.finalize_pending_max,
"active_batch": {} if engine_owned_decode_state else active_batch_summary,
"executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None,
"total_submitted": self.job_registry.submitted_count(),
"total_finished": self.job_registry.finished_count(),
"drained": self.is_drained(),
}
def is_drained(self) -> bool:
with self.condition:
return (
self.decode_legacy_shell.is_idle_locked()
and self.job_registry.is_empty()
and self.prepare_executor.is_idle()
and self.finalize_executor.is_idle()
)
def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool:
deadline = time.perf_counter() + max(0.0, timeout_sec)
while time.perf_counter() < deadline:
if self.is_drained():
return True
time.sleep(poll_interval_sec)
return self.is_drained()
def submit(
self,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None,
done_future: asyncio.Future | None = None,
engine_request_id: str | None = None,
timeout_sec: float | None = None,
skip_capacity_wait: bool = False,
admission_wait_ms_override: float | None = None,
admission_snapshot_override: Dict[str, Any] | None = None,
engine_policy_wait_ms: float = 0.0,
engine_dispatch_wait_ms: float = 0.0,
enqueue_pending: bool = True,
) -> SchedulerPendingJob:
if skip_capacity_wait:
with self.condition:
admission_snapshot = (
dict(admission_snapshot_override)
if admission_snapshot_override is not None
else dict(self._admission_snapshot_locked())
)
admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override)
else:
admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec)
job = SchedulerPendingJob(
request_id=state.request_id,
state=state,
done_event=threading.Event(),
done_loop=done_loop,
done_future=done_future,
enqueue_time=time.perf_counter(),
speed_factor=float(speed_factor),
sample_steps=int(sample_steps),
media_type=media_type,
admission_wait_ms=float(admission_wait_ms),
engine_policy_wait_ms=float(engine_policy_wait_ms),
engine_dispatch_wait_ms=float(engine_dispatch_wait_ms),
prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms),
engine_request_id=engine_request_id or state.request_id,
)
with self.condition:
self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled)
if enqueue_pending:
self.decode_legacy_shell.enqueue_pending_job_locked(job)
self.condition.notify_all()
if enqueue_pending:
self._notify_decode_runtime_state("submit")
self._runtime_update(
job.engine_request_id,
EngineStatus.QUEUED,
{
"scheduler_request_id": job.request_id,
"decode_admission_wait_ms": float(admission_wait_ms),
"engine_policy_wait_ms": float(engine_policy_wait_ms),
"engine_dispatch_wait_ms": float(engine_dispatch_wait_ms),
"admission_snapshot": dict(admission_snapshot),
},
)
return job
async def prepare_state_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> tuple[T2SRequestState, float, float]:
return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at)
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
return await self.prepare_executor.prepare_states_batch_async(specs)
async def prepare_cpu_stage_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> PreparedCpuStage:
return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
async def prepare_gpu_stage_profiled_async(
self,
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)