mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-17 15:58:15 +08:00
Add unified engine stage components for TTS processing orchestration
Introduce new modules including EngineDecodeStageMixin, EngineDispatchStageMixin, EngineFinalizeStageMixin, EnginePrepareStageMixin, and EngineStageFutureMixin. These components enhance the TTS framework by providing structured methods for managing engine stages, including decoding, dispatching, finalizing, and preparing tasks. The new architecture supports improved state management and asynchronous operations, significantly enhancing the maintainability and performance of the TTS system.
This commit is contained in:
parent
a3a5aad157
commit
d453a8e47c
@ -280,4 +280,4 @@ class EngineApiSchedulerFlow:
|
||||
spec.request_id,
|
||||
dict(submit_profile, response_headers_emitted=True),
|
||||
)
|
||||
return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers)
|
||||
return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers)
|
||||
|
||||
40
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py
Normal file
40
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class EngineDecodeStageMixin:
    """Decode-stage step of the engine stage executor.

    The host class must provide ``scheduler_worker``,
    ``decode_runtime_owner``, ``snapshot_engine_decode_runtime_state`` and
    the bookkeeping callables used below (``fail_engine_jobs``,
    ``add_engine_prefill_time``, ``add_engine_merge_time``,
    ``add_engine_decode_time``, ``enqueue_engine_finished_items``).
    """

    def run_engine_decode_runtime_once(self) -> bool:
        """Run one engine-driven decode cycle and apply its bookkeeping.

        Pending jobs are taken from ``decode_runtime_owner``, handed to
        ``scheduler_worker.execute_decode_cycle`` together with the current
        active batch, and the ``prefill_phase`` / ``decode_phase`` sections
        of the result are turned into timing updates, finished-item
        enqueues or job failures.

        Returns:
            ``False`` when engine decode control is disabled, otherwise the
            truthiness of the cycle result's ``"executed"`` entry.
        """
        if not self.scheduler_worker.is_engine_decode_control_enabled():
            return False

        runtime_state = self.snapshot_engine_decode_runtime_state()
        # ``or 0`` keeps a present-but-None counter from crashing int(),
        # matching the ``or {}`` / ``or []`` tolerance used for the phases.
        active_request_count = int(runtime_state.get("active_request_count") or 0)
        pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
            wait_for_batch=active_request_count <= 0
        )
        result = self.scheduler_worker.execute_decode_cycle(
            pending_jobs=pending_jobs,
            active_batch=self.decode_runtime_owner.get_active_batch(),
            external_bookkeeping=True,
        )
        active_batch = result.get("active_batch")
        executed = bool(result.get("executed", False))

        prefill_phase = dict(result.get("prefill_phase") or {})
        if prefill_phase.get("error"):
            self.fail_engine_jobs(
                list(prefill_phase.get("error_request_ids") or []),
                str(prefill_phase.get("error")),
            )
        else:
            # NOTE(review): the whole success bookkeeping (including the
            # finished-item enqueue) is grouped under ``else``; the extracted
            # source lost its indentation - confirm against upstream.
            prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
            self.add_engine_prefill_time(
                prefill_jobs,
                float(prefill_phase.get("prefill_elapsed_s") or 0.0),
            )
            self.add_engine_merge_time(
                [] if active_batch is None else list(active_batch.request_ids),
                float(prefill_phase.get("merge_elapsed_s") or 0.0),
            )
            self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))

        decode_phase = dict(result.get("decode_phase") or {})
        if decode_phase.get("error"):
            self.fail_engine_jobs(
                list(decode_phase.get("error_request_ids") or []),
                str(decode_phase.get("error")),
            )
        else:
            self.add_engine_decode_time(
                list(decode_phase.get("request_ids") or []),
                float(decode_phase.get("decode_elapsed_s") or 0.0),
            )
            self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))

        # The batch returned by the cycle becomes the owner's active batch
        # even when one of the phases reported an error.
        self.decode_runtime_owner.set_active_batch(active_batch)
        if executed:
            self.decode_runtime_owner.refresh_state("engine_decode_cycle")
        return executed
|
||||
93
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py
Normal file
93
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py
Normal file
@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask
|
||||
|
||||
|
||||
class EngineDispatchStageMixin:
    """Dispatch-stage steps of the engine stage executor.

    The host class must provide ``dispatch_queue_owner``,
    ``scheduler_worker``, ``decode_runtime_owner``,
    ``snapshot_engine_dispatch_state`` and the callables
    ``notify_arbiter``, ``merge_request_state_profile``,
    ``register_engine_job``, ``fail_request_state`` and
    ``_notify_dispatch_error``.
    """

    async def enqueue_prepared_state_for_dispatch(
        self,
        *,
        state: T2SRequestState,
        speed_factor: float,
        sample_steps: int,
        media_type: str,
        prepare_wall_ms: float,
        prepare_profile_total_ms: float,
        done_loop: asyncio.AbstractEventLoop | None,
        done_future: asyncio.Future | None,
        engine_request_id: str | None,
        timeout_sec: float | None,
    ) -> EngineDispatchTask:
        """Wrap a prepared request state in a dispatch task and queue it.

        The task is enqueued on ``dispatch_queue_owner``, the arbiter is
        notified, and the dispatch queue's ``waiting_count`` (sampled after
        the enqueue) is merged into the request profile as
        ``engine_dispatch_queue_depth_on_enqueue``.

        Args:
            state: Prepared request state; its ``request_id`` names the task.
            speed_factor: Stored on the task as ``float``.
            sample_steps: Stored on the task as ``int``.
            media_type: Stored on the task unchanged.
            prepare_wall_ms: Stored on the task as ``float``.
            prepare_profile_total_ms: Stored on the task as ``float``.
            done_loop: Event loop that owns ``done_future`` (may be ``None``).
            done_future: Future handed on to the worker at dispatch time.
            engine_request_id: Engine-level id; falls back to
                ``state.request_id`` when empty or ``None``.
            timeout_sec: Stored on the task unchanged.

        Returns:
            The queued ``EngineDispatchTask``.
        """
        task = EngineDispatchTask(
            request_id=state.request_id,
            state=state,
            speed_factor=float(speed_factor),
            sample_steps=int(sample_steps),
            media_type=media_type,
            prepare_wall_ms=float(prepare_wall_ms),
            prepare_profile_total_ms=float(prepare_profile_total_ms),
            done_loop=done_loop,
            done_future=done_future,
            engine_request_id=engine_request_id or state.request_id,
            timeout_sec=timeout_sec,
            enqueue_time=time.perf_counter(),
        )
        self.dispatch_queue_owner.enqueue(task)
        self.notify_arbiter()
        self.merge_request_state_profile(
            task.engine_request_id or task.request_id,
            {
                "engine_dispatch_queue_depth_on_enqueue": int(
                    self.snapshot_engine_dispatch_state()["waiting_count"]
                ),
            },
        )
        return task

    def run_engine_dispatch_once(self, policy_snapshot: Dict[str, object], worker_state: Dict[str, object]) -> bool:
        """Pop one dispatch task and submit it to the scheduler worker.

        Args:
            policy_snapshot: Arbiter policy decision; dispatch is skipped
                unless its ``"allowed"`` entry is truthy (default ``True``).
                A copy is stored on the task.
            worker_state: Worker snapshot forwarded to ``submit`` as
                ``admission_snapshot_override`` (copied).

        Returns:
            ``False`` when dispatch is not allowed or the queue is empty;
            ``True`` when a task was consumed, whether the submission
            succeeded or failed.
        """
        if not bool(policy_snapshot.get("allowed", True)):
            return False
        dispatch_task = self.dispatch_queue_owner.pop_left()
        if dispatch_task is None:
            return False

        dispatched_at = time.perf_counter()
        dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
        # Both wait metrics are populated from the same queue-wait value.
        dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
        dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
        dispatch_task.engine_policy_snapshot = dict(policy_snapshot)

        try:
            # Read the flag once. It decides both whether the worker queues
            # the job itself (``enqueue_pending``) and whether this method
            # hands it to the decode runtime owner; reading it twice could
            # queue the job in both places or in neither if it changes.
            engine_owns_decode = self.scheduler_worker.is_engine_decode_control_enabled()
            worker_job = self.scheduler_worker.submit(
                state=dispatch_task.state,
                speed_factor=dispatch_task.speed_factor,
                sample_steps=dispatch_task.sample_steps,
                media_type=dispatch_task.media_type,
                prepare_wall_ms=dispatch_task.prepare_wall_ms,
                prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
                done_loop=dispatch_task.done_loop,
                done_future=dispatch_task.done_future,
                engine_request_id=dispatch_task.engine_request_id,
                timeout_sec=dispatch_task.timeout_sec,
                skip_capacity_wait=True,
                admission_wait_ms_override=0.0,
                admission_snapshot_override=dict(worker_state),
                engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
                engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
                enqueue_pending=not engine_owns_decode,
            )
            dispatch_task.worker_job = worker_job
            self.register_engine_job(worker_job)
            if engine_owns_decode:
                self.decode_runtime_owner.enqueue_pending_job(worker_job)
                # NOTE(review): the arbiter wake-up is grouped with the
                # pending-job enqueue; the extracted source lost its
                # indentation - confirm against upstream.
                self.notify_arbiter()
            self.dispatch_queue_owner.mark_completed(1)
            return True
        except Exception as exc:
            # NOTE(review): this path does not call ``mark_completed(1)``,
            # unlike ``run_engine_finalize_once`` which marks in ``finally``
            # - confirm whether that is intentional.
            dispatch_task.error = str(exc)
            try:
                self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
            finally:
                # Always resolve the waiter, even if the state bookkeeping
                # itself raises; otherwise the caller's future never settles.
                self._notify_dispatch_error(dispatch_task, exc)
            return True
|
||||
@ -1,24 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
EngineDecodeRuntimeOwner,
|
||||
EngineDispatchTask,
|
||||
EngineGpuPrepareTask,
|
||||
EngineStatus,
|
||||
EngineTaskQueueOwner,
|
||||
SchedulerFinalizeTask,
|
||||
SchedulerPendingJob,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_decode import EngineDecodeStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_dispatch import EngineDispatchStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_finalize import EngineFinalizeStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_futures import EngineStageFutureMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_prepare import EnginePrepareStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineStageExecutor:
|
||||
class EngineStageExecutor(
|
||||
EngineStageFutureMixin,
|
||||
EnginePrepareStageMixin,
|
||||
EngineFinalizeStageMixin,
|
||||
EngineDispatchStageMixin,
|
||||
EngineDecodeStageMixin,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -62,297 +68,3 @@ class EngineStageExecutor:
|
||||
self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
|
||||
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
|
||||
self._notify_arbiter: Callable[[], None] | None = None
|
||||
|
||||
def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None:
|
||||
self._notify_arbiter = notify_arbiter
|
||||
|
||||
def notify_arbiter(self) -> None:
|
||||
if self._notify_arbiter is not None:
|
||||
self._notify_arbiter()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
|
||||
if future.done():
|
||||
return
|
||||
future.set_exception(error)
|
||||
|
||||
def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _resolve_prepare_future(
|
||||
future: asyncio.Future,
|
||||
payload: tuple[T2SRequestState, float, float],
|
||||
) -> None:
|
||||
if future.done():
|
||||
return
|
||||
future.set_result(payload)
|
||||
|
||||
def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _notify_prepare_result(
|
||||
self,
|
||||
task: EngineGpuPrepareTask,
|
||||
payload: tuple[T2SRequestState, float, float],
|
||||
) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: Any,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
|
||||
if engine_request_id not in [None, ""]:
|
||||
self.update_request_state(
|
||||
str(engine_request_id),
|
||||
EngineStatus.GPU_PREPARING,
|
||||
{
|
||||
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
|
||||
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
|
||||
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
|
||||
"text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms),
|
||||
},
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
task = EngineGpuPrepareTask(
|
||||
request_id=spec.request_id,
|
||||
cpu_stage=cpu_stage,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or spec.request_id,
|
||||
enqueue_time=time.perf_counter(),
|
||||
)
|
||||
self.prepare_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
return await done_future
|
||||
|
||||
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
self.update_request_state(
|
||||
job.engine_request_id,
|
||||
EngineStatus.READY_FOR_FINALIZE,
|
||||
{
|
||||
"finish_reason": task.item.finish_reason,
|
||||
"semantic_len": int(task.item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(task.item.finish_idx),
|
||||
},
|
||||
)
|
||||
self.finalize_queue_owner.enqueue_many(tasks)
|
||||
self.notify_arbiter()
|
||||
|
||||
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
|
||||
return self.finalize_queue_owner.take_finalize_batch(
|
||||
finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
|
||||
batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
|
||||
batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
|
||||
use_vocoder=bool(self.tts.configs.use_vocoder),
|
||||
)
|
||||
|
||||
async def enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
task = EngineDispatchTask(
|
||||
request_id=state.request_id,
|
||||
state=state,
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
prepare_wall_ms=float(prepare_wall_ms),
|
||||
prepare_profile_total_ms=float(prepare_profile_total_ms),
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or state.request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
enqueue_time=time.perf_counter(),
|
||||
)
|
||||
self.dispatch_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
self.merge_request_state_profile(
|
||||
task.engine_request_id or task.request_id,
|
||||
{
|
||||
"engine_dispatch_queue_depth_on_enqueue": int(
|
||||
self.snapshot_engine_dispatch_state()["waiting_count"]
|
||||
),
|
||||
},
|
||||
)
|
||||
return task
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
task = self.prepare_queue_owner.pop_left()
|
||||
if task is None:
|
||||
return False
|
||||
queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
|
||||
try:
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
|
||||
self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
|
||||
)
|
||||
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
|
||||
)
|
||||
self.prepare_queue_owner.mark_completed(1)
|
||||
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
|
||||
return True
|
||||
except Exception as exc:
|
||||
task.error = str(exc)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
|
||||
self._notify_prepare_error(task, exc)
|
||||
return True
|
||||
|
||||
def run_engine_finalize_once(self) -> bool:
|
||||
tasks = self.take_engine_finalize_batch_nonblocking()
|
||||
if not tasks:
|
||||
return False
|
||||
self.scheduler_worker.begin_finalize_execution(len(tasks))
|
||||
try:
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
jobs_and_items.append((job, task.item))
|
||||
if not jobs_and_items:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
|
||||
for job, item in jobs_and_items:
|
||||
self.update_request_state(
|
||||
job.engine_request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
},
|
||||
)
|
||||
synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
|
||||
for job, _ in jobs_and_items:
|
||||
job.synth_ms += float(synth_ms)
|
||||
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
|
||||
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
except Exception as exc:
|
||||
self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
|
||||
finally:
|
||||
self.scheduler_worker.end_finalize_execution(len(tasks))
|
||||
self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
|
||||
return True
|
||||
|
||||
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
if not bool(policy_snapshot.get("allowed", True)):
|
||||
return False
|
||||
dispatch_task = self.dispatch_queue_owner.pop_left()
|
||||
if dispatch_task is None:
|
||||
return False
|
||||
dispatched_at = time.perf_counter()
|
||||
dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
|
||||
dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
|
||||
dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
|
||||
dispatch_task.engine_policy_snapshot = dict(policy_snapshot)
|
||||
try:
|
||||
worker_job = self.scheduler_worker.submit(
|
||||
state=dispatch_task.state,
|
||||
speed_factor=dispatch_task.speed_factor,
|
||||
sample_steps=dispatch_task.sample_steps,
|
||||
media_type=dispatch_task.media_type,
|
||||
prepare_wall_ms=dispatch_task.prepare_wall_ms,
|
||||
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
|
||||
done_loop=dispatch_task.done_loop,
|
||||
done_future=dispatch_task.done_future,
|
||||
engine_request_id=dispatch_task.engine_request_id,
|
||||
timeout_sec=dispatch_task.timeout_sec,
|
||||
skip_capacity_wait=True,
|
||||
admission_wait_ms_override=0.0,
|
||||
admission_snapshot_override=dict(worker_state),
|
||||
engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
|
||||
engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
|
||||
enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(),
|
||||
)
|
||||
dispatch_task.worker_job = worker_job
|
||||
self.register_engine_job(worker_job)
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
self.decode_runtime_owner.enqueue_pending_job(worker_job)
|
||||
self.notify_arbiter()
|
||||
self.dispatch_queue_owner.mark_completed(1)
|
||||
return True
|
||||
except Exception as exc:
|
||||
dispatch_task.error = str(exc)
|
||||
self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
|
||||
self._notify_dispatch_error(dispatch_task, exc)
|
||||
return True
|
||||
|
||||
def run_engine_decode_runtime_once(self) -> bool:
|
||||
if not self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
return False
|
||||
runtime_state = self.snapshot_engine_decode_runtime_state()
|
||||
pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
|
||||
wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0
|
||||
)
|
||||
result = self.scheduler_worker.execute_decode_cycle(
|
||||
pending_jobs=pending_jobs,
|
||||
active_batch=self.decode_runtime_owner.get_active_batch(),
|
||||
external_bookkeeping=True,
|
||||
)
|
||||
prefill_phase = dict(result.get("prefill_phase") or {})
|
||||
if prefill_phase.get("error"):
|
||||
self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error")))
|
||||
else:
|
||||
prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
|
||||
self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0)))
|
||||
self.add_engine_merge_time(
|
||||
[] if result.get("active_batch") is None else list(result["active_batch"].request_ids),
|
||||
float(prefill_phase.get("merge_elapsed_s", 0.0)),
|
||||
)
|
||||
self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))
|
||||
decode_phase = dict(result.get("decode_phase") or {})
|
||||
if decode_phase.get("error"):
|
||||
self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error")))
|
||||
else:
|
||||
self.add_engine_decode_time(
|
||||
list(decode_phase.get("request_ids") or []),
|
||||
float(decode_phase.get("decode_elapsed_s", 0.0)),
|
||||
)
|
||||
self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))
|
||||
self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
|
||||
if result.get("executed", False):
|
||||
self.decode_runtime_owner.refresh_state("engine_decode_cycle")
|
||||
return bool(result.get("executed", False))
|
||||
|
||||
76
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py
Normal file
76
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py
Normal file
@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineFinalizeStageMixin:
    """Finalize-stage steps of the engine stage executor.

    The host class must provide ``finalize_queue_owner``,
    ``scheduler_worker``, ``tts`` and the callables ``get_engine_job``,
    ``update_request_state``, ``complete_engine_job``, ``fail_engine_jobs``
    and ``notify_arbiter``.
    """

    def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
        """Queue finished decode items for finalization.

        For every task whose job is still registered, the request state is
        moved to ``EngineStatus.READY_FOR_FINALIZE`` with the finish reason,
        semantic length and finish index. All tasks are then enqueued in one
        call and the arbiter is notified. An empty list is a no-op.
        """
        if not tasks:
            return
        for task in tasks:
            job = self.get_engine_job(task.request_id)
            if job is None:
                continue
            item = task.item
            self.update_request_state(
                job.engine_request_id,
                EngineStatus.READY_FOR_FINALIZE,
                {
                    "finish_reason": item.finish_reason,
                    "semantic_len": int(item.semantic_tokens.shape[0]),
                    "finish_idx": int(item.finish_idx),
                },
            )
        self.finalize_queue_owner.enqueue_many(tasks)
        self.notify_arbiter()

    def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
        """Take the next finalize batch according to the worker's policy.

        The policy dict from ``scheduler_worker.get_finalize_batch_policy()``
        is normalised (mode ``"async"``, max items ``1`` and wait ``0.0`` when
        a key is missing) before being passed to the queue owner.
        """
        finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
        return self.finalize_queue_owner.take_finalize_batch(
            finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
            batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
            batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
            use_vocoder=bool(self.tts.configs.use_vocoder),
        )

    def run_engine_finalize_once(self) -> bool:
        """Synthesize and complete one batch of finalize tasks.

        Returns:
            ``False`` when no batch was available or none of the batch's
            jobs is still registered; ``True`` otherwise, including when
            synthesis failed and the jobs were failed instead of completed.
        """
        tasks = self.take_engine_finalize_batch_nonblocking()
        if not tasks:
            return False
        self.scheduler_worker.begin_finalize_execution(len(tasks))
        try:
            # Resolve each job once so the wait-time update and the
            # synthesis input are guaranteed to refer to the same objects.
            resolved = []
            for task in tasks:
                job = self.get_engine_job(task.request_id)
                if job is not None:
                    resolved.append((task, job))
            if not resolved:
                # The ``finally`` clause still releases the batch.
                return False

            now = time.perf_counter()
            for task, job in resolved:
                job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)

            jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [
                (job, task.item) for task, job in resolved
            ]
            for job, item in jobs_and_items:
                self.update_request_state(
                    job.engine_request_id,
                    EngineStatus.FINALIZING,
                    {
                        "finish_reason": item.finish_reason,
                        "semantic_len": int(item.semantic_tokens.shape[0]),
                    },
                )

            synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
            batch_results = list(batch_results)
            if len(batch_results) != len(jobs_and_items):
                # zip() would silently drop the unmatched jobs, leaving them
                # neither completed nor failed; fail the whole batch instead.
                raise RuntimeError(
                    f"finalize synthesis returned {len(batch_results)} results "
                    f"for {len(jobs_and_items)} jobs"
                )
            # Every job in the batch is charged the full batch synthesis time.
            for job, _ in jobs_and_items:
                job.synth_ms += float(synth_ms)
            for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
                self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
        except Exception as exc:
            self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
        finally:
            self.scheduler_worker.end_finalize_execution(len(tasks))
            self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
        return True
|
||||
59
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py
Normal file
59
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py
Normal file
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineGpuPrepareTask
|
||||
|
||||
|
||||
class EngineStageFutureMixin:
    """Arbiter wake-up hook and future-resolution helpers for engine stages.

    Results and errors are delivered to a task's ``done_future`` by
    scheduling a resolver on the task's ``done_loop`` with
    ``call_soon_threadsafe``.
    """

    # Class-level default so ``notify_arbiter`` is a safe no-op even when the
    # host class never initialised the attribute or bound a callback.
    _notify_arbiter = None

    def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None:
        """Install the callback invoked by ``notify_arbiter``."""
        self._notify_arbiter = notify_arbiter

    def notify_arbiter(self) -> None:
        """Invoke the bound arbiter callback, if any."""
        callback = self._notify_arbiter
        if callback is not None:
            callback()

    @staticmethod
    def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
        """Set ``error`` on ``future`` unless the future is already done."""
        if future.done():
            return
        future.set_exception(error)

    @staticmethod
    def _resolve_prepare_future(
        future: asyncio.Future,
        payload: tuple[T2SRequestState, float, float],
    ) -> None:
        """Set ``payload`` as the result of ``future`` unless already done."""
        if future.done():
            return
        future.set_result(payload)

    @staticmethod
    def _schedule_future_resolution(task, resolver, value) -> None:
        """Schedule ``resolver(task.done_future, value)`` on ``task.done_loop``.

        Does nothing when the task carries no loop or no future. Delivery is
        best-effort: a ``RuntimeError`` from ``call_soon_threadsafe`` (for
        example because the loop has already been closed) is ignored, since
        there is then nobody left to observe the future.
        """
        loop, future = task.done_loop, task.done_future
        if loop is None or future is None:
            return
        try:
            loop.call_soon_threadsafe(resolver, future, value)
        except RuntimeError:
            pass

    def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
        """Deliver a dispatch failure to the task's waiter."""
        self._schedule_future_resolution(task, self._resolve_dispatch_error_future, error)

    def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
        """Deliver a GPU-prepare failure to the task's waiter."""
        # Uses the same exception resolver as dispatch errors.
        self._schedule_future_resolution(task, self._resolve_dispatch_error_future, error)

    def _notify_prepare_result(
        self,
        task: EngineGpuPrepareTask,
        payload: tuple[T2SRequestState, float, float],
    ) -> None:
        """Deliver a GPU-prepare result payload to the task's waiter."""
        self._schedule_future_resolution(task, self._resolve_prepare_future, payload)
|
||||
67
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
Normal file
67
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
Normal file
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepareTask, EngineStatus
|
||||
|
||||
|
||||
class EnginePrepareStageMixin:
    """Prepare-stage steps of the engine stage executor.

    The host class must provide ``scheduler_worker``,
    ``prepare_queue_owner`` and the callables ``update_request_state``,
    ``merge_request_state_profile``, ``fail_request_state``,
    ``notify_arbiter``, ``_notify_prepare_result`` and
    ``_notify_prepare_error``.
    """

    async def prepare_state_via_engine_gpu_queue(
        self,
        *,
        spec: Any,
        prepare_submit_at: float,
        engine_request_id: str | None,
    ) -> tuple[T2SRequestState, float, float]:
        """Run the CPU prepare stage, then queue the GPU stage and await it.

        After the CPU stage completes, its queue/run timings are recorded on
        the request state (status ``EngineStatus.GPU_PREPARING``) when an
        engine request id is given. A GPU prepare task bound to the running
        loop is then enqueued and the arbiter notified.

        Args:
            spec: Request spec; forwarded to the CPU stage and used for its
                ``request_id``.
            prepare_submit_at: Submission timestamp forwarded to the CPU stage.
            engine_request_id: Engine-level id; the task falls back to
                ``spec.request_id`` when this is empty or ``None``.

        Returns:
            Whatever resolves the task's future; ``run_engine_prepare_once``
            resolves it with ``(state, exec_started_at, exec_finished_at)``.

        Raises:
            Exception: The error set on the future when GPU prepare fails.
        """
        cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
        if engine_request_id not in (None, ""):
            prompt_profile = cpu_stage.prompt_cpu_profiled
            target_profile = cpu_stage.target_cpu_profiled
            self.update_request_state(
                str(engine_request_id),
                EngineStatus.GPU_PREPARING,
                {
                    "prompt_text_cpu_queue_ms": float(prompt_profile.queue_ms),
                    "prompt_text_cpu_run_ms": float(prompt_profile.run_ms),
                    "text_cpu_queue_ms": float(target_profile.queue_ms),
                    "text_cpu_run_ms": float(target_profile.run_ms),
                },
            )
        loop = asyncio.get_running_loop()
        done_future = loop.create_future()
        task = EngineGpuPrepareTask(
            request_id=spec.request_id,
            cpu_stage=cpu_stage,
            done_loop=loop,
            done_future=done_future,
            engine_request_id=engine_request_id or spec.request_id,
            enqueue_time=time.perf_counter(),
        )
        self.prepare_queue_owner.enqueue(task)
        self.notify_arbiter()
        return await done_future

    def run_engine_prepare_once(self) -> bool:
        """Pop one GPU prepare task, execute it and resolve its waiter.

        The GPU stage coroutine is driven to completion with ``asyncio.run``.
        The time the task spent queued is recorded as
        ``engine_gpu_prepare_queue_wait_ms`` on the prepared state and, when
        the task has an engine request id, on the request profile.

        Returns:
            ``False`` when the queue is empty; ``True`` when a task was
            consumed, whether the GPU stage succeeded or failed.
        """
        task = self.prepare_queue_owner.pop_left()
        if task is None:
            return False
        queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
        try:
            state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
                self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
            )
            state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
            if task.engine_request_id not in (None, ""):
                self.merge_request_state_profile(
                    str(task.engine_request_id),
                    {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
                )
            self.prepare_queue_owner.mark_completed(1)
            self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
            return True
        except Exception as exc:
            # NOTE(review): this path does not call ``mark_completed(1)``,
            # unlike ``run_engine_finalize_once`` which marks in ``finally``
            # - confirm whether that is intentional.
            task.error = str(exc)
            try:
                self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
            finally:
                # Always resolve the waiter, even if the state bookkeeping
                # itself raises; otherwise the awaiting coroutine hangs.
                self._notify_prepare_error(task, exc)
            return True
|
||||
Loading…
x
Reference in New Issue
Block a user