Enhance TTS processing with new reference specification handling and profiling metrics

Refactor the PrepareCoordinator and related components to improve reference-specification handling in the TTS system. Introduce new methods for building reference prompts and extracting reference specs, split reference-spec extraction into its own schedulable stage, and add detailed profiling metrics for performance monitoring. Update the PrepareRefSemanticBatchWorker with finer-grained timing metrics and caching for resampling. These changes make audio processing and reference-data preparation in the TTS framework easier to profile, schedule, and maintain.
baicai-1145 2026-03-13 16:45:38 +08:00
parent c94de2f2cb
commit 8a444c10b7
18 changed files with 1006 additions and 174 deletions
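
The profiling added throughout this commit follows one pattern: each stage is bracketed with time.perf_counter() and the elapsed wall time, in milliseconds, is written into a flat profile dict that travels alongside the stage result. A minimal sketch of that pattern, using a hypothetical timed() helper (the diff writes these brackets inline rather than through a helper):

import time
from contextlib import contextmanager

@contextmanager
def timed(profile: dict, key: str):
    # Record a stage's elapsed wall time, in milliseconds, under a flat key,
    # the way _extract_ref_spec_profile_from_raw does inline.
    start = time.perf_counter()
    try:
        yield
    finally:
        profile[key] = (time.perf_counter() - start) * 1000.0

profile: dict = {}
with timed(profile, "ref_spec_spectrogram_ms"):
    pass  # stand-in for spectrogram_torch(...)

Keeping the keys flat and float-valued is what lets downstream code merge profiles with plain dict updates, as build_request_state_from_parts does.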

View File

@@ -996,21 +996,39 @@ class TTS:
return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr)
return spec, audio, raw_audio, raw_sr
def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
profile = {
"ref_spec_to_device_ms": 0.0,
"ref_spec_main_resample_ms": 0.0,
"ref_spec_norm_ms": 0.0,
"ref_spec_spectrogram_ms": 0.0,
"ref_spec_post_resample_ms": 0.0,
}
to_device_start = time.perf_counter()
raw_audio_device = raw_audio.to(self.configs.device).float()
profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0
if raw_sr != self.configs.sampling_rate:
resample_start = time.perf_counter()
audio = raw_audio_device
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0
else:
audio = raw_audio_device
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
norm_start = time.perf_counter()
maxx = audio.abs().max()
if maxx > 1:
audio /= min(2, maxx)
profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0
spec_start = time.perf_counter()
spec = spectrogram_torch(
audio,
self.configs.filter_length,
@@ -1019,15 +1037,18 @@ class TTS:
self.configs.win_length,
center=False,
)
profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0
if self.configs.is_half:
spec = spec.half()
if self.is_v2pro:
post_resample_start = time.perf_counter()
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0
if self.configs.is_half:
audio = audio.half()
else:
audio = None
return spec, audio, raw_audio, raw_sr
return spec, audio, raw_audio, raw_sr, profile
def extract_ref_spec(self, ref_audio_path: str):
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
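A pattern worth noting in this hunk: the old entry point survives as a thin wrapper that calls the new profiled variant and discards the metrics, so existing callers keep their signature. Sketched generically (names hypothetical):

class Extractor:
    def extract_profiled(self, x):
        # New variant: returns the result plus a timing profile.
        return x, {"stage_ms": 0.0}

    def extract(self, x):
        # Back-compat wrapper: same signature as before, profile discarded.
        result, _profile = self.extract_profiled(x)
        return result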

View File

@@ -228,8 +228,74 @@ class PrepareCoordinator:
def _load_ref_audio_raw(self, ref_audio_path: str):
return self.tts._load_ref_audio_raw(ref_audio_path)
def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int):
load_profile = {"audio_load_ms": 0.0}
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
return {
"prompt_semantic": prompt_semantic,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
**load_profile,
"audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)),
"audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)),
"audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": float(
worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
+ worker_profile.get("prompt_semantic_forward_ms", 0.0)
+ worker_profile.get("prompt_semantic_scatter_ms", 0.0)
),
**{key: float(value) for key, value in worker_profile.items()},
"ref_spec_wait_ms": 0.0,
"ref_spec_ms": 0.0,
"bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0))
+ float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
+ float(worker_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)),
},
}
wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats:
prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
return {
"prompt_semantic": prompt_semantic,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": 0.0,
"audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
"audio_stage_slots": float(stage_stats.get("slots", 0.0)),
"audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)),
"prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)),
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_pack_ms": 0.0,
"prompt_semantic_h2d_ms": 0.0,
"prompt_semantic_ssl_forward_ms": 0.0,
"prompt_semantic_hidden_length_ms": 0.0,
"prompt_semantic_extract_latent_ms": 0.0,
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)),
"prompt_semantic_batch_size": 1.0,
"prompt_semantic_batch_samples": 0.0,
"ref_spec_wait_ms": 0.0,
"ref_spec_ms": 0.0,
"bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)),
},
}
def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int):
return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr)
return (spec, audio), profile
@staticmethod
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
@@ -523,7 +589,7 @@ class PrepareCoordinator:
finally:
self.text_feature_gate.release()
async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult:
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
submit_at = time.perf_counter()
started_at = float(submit_at)
@@ -538,19 +604,7 @@
prompt_semantic_task = asyncio.create_task(
self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
)
await self.ref_spec_gate.acquire()
try:
ref_spec_task = asyncio.create_task(
self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
)
(prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather(
prompt_semantic_task,
ref_spec_task,
)
finally:
self.ref_spec_gate.release()
refer_spec = ref_spec_profiled.result
prompt_semantic, prompt_semantic_profile = await prompt_semantic_task
limiter_snapshot = (
self.tts.prepare_ref_audio_stage_limiter.snapshot()
if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None
@@ -561,21 +615,15 @@
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
)
audio_stage_wait_ms = (
float(load_profiled.queue_ms)
+ float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0))
+ float(ref_spec_profiled.queue_ms)
)
finished_at = time.perf_counter()
result = {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_queue_ms": float(load_profiled.queue_ms),
"audio_load_ms": float(load_profiled.run_ms),
"audio_stage_wait_ms": float(audio_stage_wait_ms),
"audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"audio_stage_slots": float(
max(
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
@@ -593,6 +641,17 @@
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
"prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)),
"prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)),
"prompt_semantic_ssl_forward_ms": float(
prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0)
),
"prompt_semantic_hidden_length_ms": float(
prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0)
),
"prompt_semantic_extract_latent_ms": float(
prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0)
),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
@@ -603,14 +662,10 @@
"prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
),
"ref_spec_wait_ms": float(ref_spec_profiled.queue_ms),
"ref_spec_ms": float(ref_spec_profiled.run_ms),
"bundle_total_ms": float(
load_profiled.queue_ms
+ load_profiled.run_ms
+ prompt_semantic_ms
+ ref_spec_profiled.queue_ms
+ ref_spec_profiled.run_ms
),
},
}
@@ -623,21 +678,26 @@
await self.ref_audio_gate.acquire()
try:
if hasattr(self.tts, "extract_ref_audio_bundle_async"):
submit_at = time.perf_counter()
started_at = time.perf_counter()
result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path)
finished_at = time.perf_counter()
return ProfiledResult(
result=result,
submit_at=float(submit_at),
started_at=float(started_at),
finished_at=float(finished_at),
)
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path)
raw_audio, raw_sr = load_profiled.result
submit_at = time.perf_counter()
started_at = time.perf_counter()
result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr)
result.setdefault("profile", {})
result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms)
result["profile"]["audio_load_ms"] = float(load_profiled.run_ms)
finished_at = time.perf_counter()
return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at))
finally:
self.ref_audio_gate.release()
async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult:
await self.ref_spec_gate.acquire()
try:
return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
finally:
self.ref_spec_gate.release()
def _release_split_stage_slot(self) -> None:
self._mark_leave()
self._inflight_gate.release()
@@ -682,101 +742,318 @@
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
try:
g2pw_pair_start = time.perf_counter()
g2pw_pair_task = asyncio.create_task(
self._run_g2pw_pair_stage(
cpu_stage.prompt_cpu_profiled.result,
cpu_stage.target_cpu_profiled.result,
)
)
ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
(prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather(
g2pw_pair_task,
ref_audio_task,
)
g2pw_pair_end = time.perf_counter()
text_pair_start = time.perf_counter()
text_feature_pair_task = asyncio.create_task(
self._run_text_feature_pair_stage(
prompt_g2pw_profiled.result,
target_g2pw_profiled.result,
cpu_stage.prompt_cpu_profiled.run_ms,
cpu_stage.target_cpu_profiled.run_ms,
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
target_base_profile=dict(target_g2pw_profiled.profile or {}),
)
)
prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task
text_pair_end = time.perf_counter()
state = build_request_state_from_parts(
tts=self.tts,
spec=cpu_stage.spec,
prompt_text=cpu_stage.prompt_text,
text=cpu_stage.text,
prompt_result=prompt_feature_profiled.result,
target_result=target_feature_profiled.result,
ref_audio_bundle=ref_audio_profiled.result,
prepare_start=cpu_stage.prepare_start,
prepare_sync_start=cpu_stage.prepare_start,
profile_overrides={
"executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
"text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0),
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": 0.0,
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
"prompt_text_cpu_admission_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"prompt_text_cpu_backpressure_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"prompt_text_cpu_capacity_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
"text_cpu_admission_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"text_cpu_backpressure_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"text_cpu_capacity_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"text_feature_queue_ms": target_feature_profiled.queue_ms,
"text_feature_run_ms": target_feature_profiled.run_ms,
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
"worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
"worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
phase_one = await self._prepare_gpu_phase_one(cpu_stage)
phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one)
return self._build_gpu_prepare_result(
cpu_stage,
phase_one,
phase_two,
extra_profile={
"engine_prepare_audio_phase_mode": 0.0,
"engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]),
"engine_prepare_audio_phase_batch_size": 1.0,
"engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]),
"engine_prepare_text_phase_batch_size": 1.0,
},
)
prepare_exec_finished_at = time.perf_counter()
state.prepare_profile["executor_run_wall_ms"] = max(
0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0
finally:
self._release_split_stage_slot()
async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]:
phase_start = time.perf_counter()
g2pw_pair_task = asyncio.create_task(
self._run_g2pw_pair_stage(
cpu_stage.prompt_cpu_profiled.result,
cpu_stage.target_cpu_profiled.result,
)
return state, cpu_stage.prepare_start, prepare_exec_finished_at
)
ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path)))
prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task
g2pw_pair_end = time.perf_counter()
ref_audio_profiled = await ref_audio_task
phase_end = time.perf_counter()
return {
"prompt_g2pw_profiled": prompt_g2pw_profiled,
"target_g2pw_profiled": target_g2pw_profiled,
"ref_audio_profiled": ref_audio_profiled,
"ref_spec_result": None,
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0),
"phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0),
}
async def _prepare_gpu_phase_two(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
) -> Dict[str, Any]:
phase_start = time.perf_counter()
prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"]
target_g2pw_profiled = phase_one["target_g2pw_profiled"]
prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage(
prompt_g2pw_profiled.result,
target_g2pw_profiled.result,
cpu_stage.prompt_cpu_profiled.run_ms,
cpu_stage.target_cpu_profiled.run_ms,
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
target_base_profile=dict(target_g2pw_profiled.profile or {}),
)
phase_end = time.perf_counter()
return {
"prompt_feature_profiled": prompt_feature_profiled,
"target_feature_profiled": target_feature_profiled,
"phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0),
}
def _build_gpu_prepare_result(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"]
target_g2pw_profiled = phase_one["target_g2pw_profiled"]
ref_audio_profiled = phase_one["ref_audio_profiled"]
ref_spec_result = phase_one.get("ref_spec_result")
prompt_feature_profiled = phase_two["prompt_feature_profiled"]
target_feature_profiled = phase_two["target_feature_profiled"]
profile_overrides = {
"executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
"prepare_submit_ts": float(cpu_stage.prepare_submit_at),
"prepare_cpu_start_ts": float(cpu_stage.prepare_start),
"prepare_cpu_done_ts": float(
max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at)
),
"prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at),
"prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at),
"text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at),
"text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at),
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
"text_feature_pair_ms": float(phase_two["phase_wall_ms"]),
"g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]),
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": 0.0,
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
"prompt_text_cpu_admission_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"prompt_text_cpu_backpressure_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"prompt_text_cpu_capacity_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
"text_cpu_admission_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"text_cpu_backpressure_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"text_cpu_capacity_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"text_feature_queue_ms": target_feature_profiled.queue_ms,
"text_feature_run_ms": target_feature_profiled.run_ms,
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
"worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
"worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
}
if extra_profile:
profile_overrides.update({key: float(value) for key, value in extra_profile.items()})
ref_audio_bundle = dict(ref_audio_profiled.result)
ref_audio_profile = dict(ref_audio_bundle.get("profile", {}))
if ref_spec_result is not None:
refer_spec, ref_spec_profiled = ref_spec_result
ref_audio_bundle["refer_spec"] = refer_spec
ref_audio_profile.update(
{
"ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)),
"ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)),
}
)
else:
ref_audio_bundle["refer_spec"] = None
ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0)
ref_audio_bundle["profile"] = ref_audio_profile
state = build_request_state_from_parts(
tts=self.tts,
spec=cpu_stage.spec,
prompt_text=cpu_stage.prompt_text,
text=cpu_stage.text,
prompt_result=prompt_feature_profiled.result,
target_result=target_feature_profiled.result,
ref_audio_bundle=ref_audio_bundle,
prepare_start=cpu_stage.prepare_start,
prepare_sync_start=cpu_stage.prepare_start,
profile_overrides=profile_overrides,
)
prepare_exec_finished_at = time.perf_counter()
state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0)
return state, cpu_stage.prepare_start, prepare_exec_finished_at
async def prepare_ref_spec_stages_async(
self,
phase_ones: list[Dict[str, Any]],
) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
async def _one(phase_one: Dict[str, Any]):
ref_audio_profiled = phase_one["ref_audio_profiled"]
raw_audio = ref_audio_profiled.result["raw_audio"]
raw_sr = int(ref_audio_profiled.result["raw_sr"])
profiled = await self._run_ref_spec_stage(raw_audio, raw_sr)
refer_spec, profile = profiled.result
merged_profile = dict(profile)
merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms)
merged_profile["ref_spec_ms"] = float(profiled.run_ms)
return refer_spec, merged_profile
if not phase_ones:
return []
return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True))
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
refer_spec, profile = ref_spec_result
state.refer_spec = refer_spec
state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0))
state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0))
state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0))
state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0))
state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0))
state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0))
state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0))
async def prepare_gpu_stages_profiled_async(
self,
cpu_stages: list[PreparedCpuStage],
) -> list[tuple[T2SRequestState, float, float] | Exception]:
if not cpu_stages:
return []
if len(cpu_stages) == 1:
single_stage = cpu_stages[0]
try:
return [await self.prepare_gpu_stage_profiled_async(single_stage)]
except Exception as exc: # noqa: PERF203
return [exc]
phase_one_started_at = time.perf_counter()
phase_one_results = await asyncio.gather(
*[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages],
return_exceptions=True,
)
phase_one_finished_at = time.perf_counter()
phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0)
outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages)
pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = []
for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)):
if isinstance(phase_one, Exception):
outputs[index] = phase_one
self._release_split_stage_slot()
continue
pending_phase_two.append((index, cpu_stage, phase_one))
phase_two_started_at = time.perf_counter()
phase_two_results = await asyncio.gather(
*[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two],
return_exceptions=True,
)
phase_two_finished_at = time.perf_counter()
phase_two_wall_ms = max(0.0, (phase_two_finished_at - phase_two_started_at) * 1000.0)
for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results):
try:
if isinstance(phase_two, Exception):
outputs[index] = phase_two
continue
outputs[index] = self._build_gpu_prepare_result(
cpu_stage,
phase_one,
phase_two,
extra_profile={
"engine_prepare_audio_phase_mode": 1.0,
"engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms),
"engine_prepare_audio_phase_batch_size": float(len(cpu_stages)),
"engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms),
"engine_prepare_text_phase_batch_size": float(len(pending_phase_two)),
},
)
except Exception as exc: # noqa: PERF203
outputs[index] = exc
finally:
self._release_split_stage_slot()
return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs]
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: list[PreparedCpuStage],
) -> list[Dict[str, Any] | Exception]:
if not cpu_stages:
return []
return list(
await asyncio.gather(
*[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages],
return_exceptions=True,
)
)
async def prepare_gpu_text_phases_async(
self,
items: list[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> list[Dict[str, Any] | Exception]:
if not items:
return []
return list(
await asyncio.gather(
*[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items],
return_exceptions=True,
)
)
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
try:
return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile)
finally:
self._release_split_stage_slot()
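The split into _prepare_gpu_phase_one (G2PW plus the reference prompt-semantic stage) and _prepare_gpu_phase_two (text features) is what lets prepare_gpu_stages_profiled_async batch each phase across requests while isolating per-request failures. A reduced sketch of that gather-per-phase shape, assuming phase_one and phase_two are async callables standing in for the coordinator's methods:

import asyncio
from typing import Any, Awaitable, Callable

async def run_two_phase_batch(
    stages: list[Any],
    phase_one: Callable[[Any], Awaitable[Any]],
    phase_two: Callable[[Any, Any], Awaitable[Any]],
) -> list[Any]:
    # Phase one runs for every stage; return_exceptions captures per-slot
    # failures instead of cancelling the whole batch.
    ones = await asyncio.gather(*[phase_one(s) for s in stages], return_exceptions=True)
    outputs: list[Any] = [None] * len(stages)
    pending = []
    for index, (stage, one) in enumerate(zip(stages, ones)):
        if isinstance(one, Exception):
            outputs[index] = one  # this request fails alone
        else:
            pending.append((index, stage, one))
    # Phase two runs only for slots that survived phase one.
    twos = await asyncio.gather(
        *[phase_two(stage, one) for _, stage, one in pending], return_exceptions=True
    )
    for (index, _, one), two in zip(pending, twos):
        outputs[index] = two if isinstance(two, Exception) else (one, two)
    return outputs

Slots that raise in phase one are recorded as exceptions and skipped in phase two, matching how the coordinator releases each failed stage's slot immediately.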

View File

@@ -15,6 +15,7 @@ REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000
_RESAMPLE_CACHE_LOCK = threading.Lock()
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {}
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
@@ -28,6 +29,16 @@ def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.tran
return transform
def _get_resample_stream(device: str) -> torch.cuda.Stream:
device_key = str(device)
with _RESAMPLE_CACHE_LOCK:
stream = _RESAMPLE_STREAM_CACHE.get(device_key)
if stream is None:
stream = torch.cuda.Stream(device=device_key)
_RESAMPLE_STREAM_CACHE[device_key] = stream
return stream
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
if resample_device not in {"cpu", "cuda"}:
@@ -37,10 +48,20 @@ def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wa
wav_mono = raw_audio
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
wav_mono = wav_mono.mean(0, keepdim=True)
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
if resample_device == "cuda":
stream = _get_resample_stream(resample_device)
with torch.cuda.stream(stream):
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
stream.synchronize()
wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous()
else:
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
raise OSError("Reference audio is outside the 3~10 second range; please use a different clip")
if zero_wav_samples > 0:
@@ -256,37 +277,56 @@ class PrepareRefSemanticBatchWorker:
batch_samples = int(wav_lengths.sum().item())
max_wav_len = int(wav_lengths.max().item())
pack_start = time.perf_counter()
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
for batch_index, wav in enumerate(prepared_wavs):
wav_len = int(wav.shape[0])
input_values_cpu[batch_index, :wav_len] = wav
attention_mask_cpu[batch_index, :wav_len] = 1
pack_ms = (time.perf_counter() - pack_start) * 1000.0
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
h2d_ms = 0.0
ssl_forward_ms = 0.0
hidden_length_ms = 0.0
extract_latent_ms = 0.0
if self.stage_limiter is None:
h2d_start = time.perf_counter()
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
forward_start = time.perf_counter()
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
ssl_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_length_start = time.perf_counter()
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
latent_start = time.perf_counter()
codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
h2d_start = time.perf_counter()
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
forward_start = time.perf_counter()
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
ssl_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_length_start = time.perf_counter()
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
latent_start = time.perf_counter()
codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms)
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
scatter_start = time.perf_counter()
@@ -308,6 +348,11 @@ class PrepareRefSemanticBatchWorker:
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_pack_ms": float(pack_ms),
"prompt_semantic_h2d_ms": float(h2d_ms),
"prompt_semantic_ssl_forward_ms": float(ssl_forward_ms),
"prompt_semantic_hidden_length_ms": float(hidden_length_ms),
"prompt_semantic_extract_latent_ms": float(extract_latent_ms),
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_calls": 1.0,

View File

@@ -55,7 +55,7 @@ class T2SRequestState:
all_phones: torch.LongTensor
all_bert_features: torch.Tensor
prompt_semantic: torch.LongTensor
refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]]
refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
raw_audio: torch.Tensor
raw_sr: int
@@ -188,7 +188,11 @@ def build_request_state_from_parts(
ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0))
bundle_profile = ref_audio_bundle.get("profile", {})
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
refer_spec_value = ref_audio_bundle.get("refer_spec")
if refer_spec_value in [None, ()]:
spec_audio, audio_16k = None, None
else:
spec_audio, audio_16k = refer_spec_value
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []):
if aux_ref_audio_path in [None, ""]:
@@ -323,6 +327,11 @@ def build_request_state_from_parts(
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)),
"prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)),
"prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)),
"prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)),
"prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)),
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)),
@@ -331,6 +340,11 @@ def build_request_state_from_parts(
"prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)),
"ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": ref_spec_ms,
"ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)),
"ref_audio_bundle_ms": ref_audio_bundle_ms,
"tensorize_ms": tensorize_ms,
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
@@ -352,7 +366,7 @@ def build_request_state_from_parts(
all_phones=all_phones,
all_bert_features=all_bert_features,
prompt_semantic=prompt_semantic,
refer_spec=(spec_audio, audio_16k),
refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)),
aux_refer_specs=aux_refer_specs,
raw_audio=raw_audio,
raw_sr=raw_sr,

View File

@@ -21,6 +21,14 @@ class EngineRegistryBridgeFacade:
def engine_prepare_queue_owner(self):
return self.owner.engine_prepare_queue_owner
@property
def engine_prepare_text_queue_owner(self):
return self.owner.engine_prepare_text_queue_owner
@property
def engine_prepare_ref_spec_queue_owner(self):
return self.owner.engine_prepare_ref_spec_queue_owner
@property
def engine_finalize_queue_owner(self):
return self.owner.engine_finalize_queue_owner
@@ -82,7 +90,33 @@ class EngineRegistryBridgeFacade:
return self.request_registry.snapshot()
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16)
ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16)
return {
"waiting_count": int(audio_snapshot.get("waiting_count", 0))
+ int(text_snapshot.get("waiting_count", 0))
+ int(ref_spec_snapshot.get("waiting_count", 0)),
"audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)),
"text_waiting_count": int(text_snapshot.get("waiting_count", 0)),
"ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)),
"audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])),
"text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])),
"ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])),
"peak_waiting": int(
max(
int(audio_snapshot.get("peak_waiting", 0)),
int(text_snapshot.get("peak_waiting", 0)),
int(ref_spec_snapshot.get("peak_waiting", 0)),
)
),
"total_submitted": int(audio_snapshot.get("total_submitted", 0)),
"total_completed": int(audio_snapshot.get("total_completed", 0)),
"text_total_submitted": int(text_snapshot.get("total_submitted", 0)),
"text_total_completed": int(text_snapshot.get("total_completed", 0)),
"ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)),
"ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)),
}
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
@@ -107,6 +141,8 @@ class EngineRegistryBridgeFacade:
def _is_engine_drained(self) -> bool:
prepare_empty = self.engine_prepare_queue_owner.is_drained()
prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained()
prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained()
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
finalize_empty = self.engine_finalize_queue_owner.is_drained()
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
@@ -114,6 +150,8 @@
worker_state = self.scheduler_worker.snapshot()
return bool(
prepare_empty
and prepare_text_empty
and prepare_ref_spec_empty
and dispatch_empty
and finalize_empty
and decode_pending_empty
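With prepare split across three queues, the facade's merged snapshot sums waiting counts across queues and takes the max of the peaks, and the drain check requires every queue to be empty at once. A condensed sketch of the aggregation (merge_prepare_snapshots is a hypothetical helper; the diff inlines this in _snapshot_engine_prepare_state):

from typing import Any, Dict

def merge_prepare_snapshots(audio: Dict[str, Any], text: Dict[str, Any],
                            ref_spec: Dict[str, Any]) -> Dict[str, Any]:
    # Counts sum across the three queues; the peak takes the max, mirroring
    # how the facade combines per-queue snapshots into one view.
    parts = (audio, text, ref_spec)
    return {
        "waiting_count": sum(int(s.get("waiting_count", 0)) for s in parts),
        "peak_waiting": max(int(s.get("peak_waiting", 0)) for s in parts),
    }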

View File

@@ -110,6 +110,8 @@ class EngineCompositionBuilder:
get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s,
)
owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")
@@ -119,6 +121,8 @@
tts=owner.tts,
scheduler_worker=owner.scheduler_worker,
prepare_queue_owner=owner.engine_prepare_queue_owner,
prepare_text_queue_owner=owner.engine_prepare_text_queue_owner,
prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner,
finalize_queue_owner=owner.engine_finalize_queue_owner,
dispatch_queue_owner=owner.engine_dispatch_queue_owner,
decode_runtime_owner=owner.engine_decode_runtime_owner,

View File

@@ -129,7 +129,7 @@ class EnginePolicyArbiterController:
self.state.total_ticks += 1
if stage == "idle":
self.state.total_idle_ticks += 1
elif stage == "prepare":
elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}:
self.state.total_prepare_dispatches += 1
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
elif stage == "finalize":
@@ -282,7 +282,11 @@ class EnginePolicyArbiterController:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0))
prepare_state = self.snapshot_prepare_state()
prepare_waiting = int(prepare_state.get("waiting_count", 0))
prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0))
prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0))
prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0))
finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
decode_runtime_state = self.snapshot_decode_runtime_state()
@@ -291,6 +295,9 @@
worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio"))
prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text"))
prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec"))
finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
@@ -316,14 +323,31 @@
and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
):
return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
if (
finalize_waiting > 0
and prepare_ref_spec_waiting > 0
and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0)
):
return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state
if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
return "finalize", "finalize_aging", policy_snapshot, worker_state
if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0:
return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
return "prepare", "prepare_aging", policy_snapshot, worker_state
if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6):
return "prepare_text", "prepare_aging", policy_snapshot, worker_state
if (
prepare_ref_spec_waiting > 0
and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6)
):
return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state
return "prepare_audio", "prepare_aging", policy_snapshot, worker_state
if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
if decode_waiting > 0 and policy_allowed:
@@ -331,5 +355,9 @@
if finalize_waiting > 0:
return "finalize", "finalize_fallback", policy_snapshot, worker_state
if prepare_waiting > 0:
return "prepare", "prepare_fallback", policy_snapshot, worker_state
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
return "prepare_text", "prepare_fallback", policy_snapshot, worker_state
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0:
return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state
return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state
return "idle", "no_pending_work", policy_snapshot, worker_state

View File

@@ -324,8 +324,24 @@ class EngineGpuPrepareTask:
done_future: asyncio.Future | None
engine_request_id: str | None
enqueue_time: float
queue_wait_ms: float = 0.0
phase: str = "audio"
audio_enqueue_time: float = 0.0
audio_start_time: float = 0.0
audio_end_time: float = 0.0
text_enqueue_time: float = 0.0
text_start_time: float = 0.0
text_end_time: float = 0.0
ref_spec_enqueue_time: float = 0.0
ref_spec_start_time: float = 0.0
ref_spec_end_time: float = 0.0
audio_queue_wait_ms: float = 0.0
text_queue_wait_ms: float = 0.0
ref_spec_queue_wait_ms: float = 0.0
admission_wait_ms: float = 0.0
phase_one: Dict[str, Any] | None = None
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None
state_result: T2SRequestState | None = None
cancelled: bool = False
error: str | None = None

View File

@@ -14,6 +14,8 @@ class EngineStageOrchestrator:
executor: EngineStageExecutor,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
@@ -22,6 +24,8 @@ class EngineStageOrchestrator:
self.executor = executor
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.prepare_text_queue_owner = prepare_text_queue_owner
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner
@@ -45,7 +49,17 @@
def peek_queue_age_ms(self, queue_name: str) -> float:
if queue_name == "prepare":
return max(
self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"),
self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"),
self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"),
)
if queue_name == "prepare_audio":
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "prepare_text":
return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "prepare_ref_spec":
return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "finalize":
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
if queue_name == "decode_runtime_pending":
@@ -62,6 +76,10 @@
return True
if self.prepare_queue_owner.has_items():
return True
if self.prepare_text_queue_owner.has_items():
return True
if self.prepare_ref_spec_queue_owner.has_items():
return True
if self.finalize_queue_owner.has_items():
return True
return self.dispatch_queue_owner.has_items()
@@ -79,6 +97,12 @@
executed = False
if stage == "prepare":
executed = self.executor.run_engine_prepare_once()
elif stage == "prepare_audio":
executed = self.executor.run_engine_prepare_audio_once()
elif stage == "prepare_text":
executed = self.executor.run_engine_prepare_text_once()
elif stage == "prepare_ref_spec":
executed = self.executor.run_engine_prepare_ref_spec_once()
elif stage == "finalize":
executed = self.executor.run_engine_finalize_once()
elif stage == "decode_dispatch":

View File

@@ -24,6 +24,8 @@ class EngineStageCoordinator:
tts: TTS,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
@@ -45,6 +47,8 @@
tts=tts,
scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner,
prepare_text_queue_owner=prepare_text_queue_owner,
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner,
@@ -66,6 +70,8 @@
executor=self.executor,
scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner,
prepare_text_queue_owner=prepare_text_queue_owner,
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner,
@@ -144,6 +150,15 @@
def run_engine_prepare_once(self) -> bool:
return self.executor.run_engine_prepare_once()
def run_engine_prepare_audio_once(self) -> bool:
return self.executor.run_engine_prepare_audio_once()
def run_engine_prepare_text_once(self) -> bool:
return self.executor.run_engine_prepare_text_once()
def run_engine_prepare_ref_spec_once(self) -> bool:
return self.executor.run_engine_prepare_ref_spec_once()
def run_engine_finalize_once(self) -> bool:
return self.executor.run_engine_finalize_once()

View File

@@ -24,6 +24,10 @@ class EngineDispatchStageMixin:
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
error = RuntimeError("ref_spec async stage failed before dispatch")
self.fail_request_state(engine_request_id or state.request_id, str(error))
raise error
task = EngineDispatchTask(
request_id=state.request_id,
state=state,

View File

@@ -31,6 +31,8 @@ class EngineStageExecutor(
tts: TTS,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
@@ -51,6 +53,8 @@
self.tts = tts
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.prepare_text_queue_owner = prepare_text_queue_owner
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner

View File

@@ -39,10 +39,37 @@ class EngineFinalizeStageMixin:
tasks = self.take_engine_finalize_batch_nonblocking()
if not tasks:
return False
self.scheduler_worker.begin_finalize_execution(len(tasks))
ready_tasks: List[SchedulerFinalizeTask] = []
failed_tasks: List[SchedulerFinalizeTask] = []
deferred_tasks: List[SchedulerFinalizeTask] = []
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
failed_tasks.append(task)
continue
if job.state.refer_spec is None:
deferred_tasks.append(task)
self.merge_request_state_profile(
job.engine_request_id or job.request_id,
{
"engine_finalize_ref_spec_blocked": 1.0,
},
)
continue
ready_tasks.append(task)
if deferred_tasks:
self.finalize_queue_owner.enqueue_many(deferred_tasks)
if failed_tasks:
self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed")
if not ready_tasks:
self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True)
return False
self.scheduler_worker.begin_finalize_execution(len(ready_tasks))
try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
for task in tasks:
for task in ready_tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
@@ -50,7 +77,7 @@ class EngineFinalizeStageMixin:
if not jobs_and_items:
return False
now = time.perf_counter()
for task in tasks:
for task in ready_tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
@@ -69,8 +96,8 @@
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc:
self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc))
finally:
self.scheduler_worker.end_finalize_execution(len(tasks))
self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
self.scheduler_worker.end_finalize_execution(len(ready_tasks))
self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True)
return True
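
Finalize now triages each dequeued task into ready, failed, or deferred: tasks whose async ref-spec stage failed are failed outright, tasks still missing refer_spec are re-enqueued, and only ready tasks enter batch finalization. A compact sketch of the triage (Task here is a stand-in, not the real SchedulerFinalizeTask):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Task:
    request_id: str
    ref_spec_failed: bool
    refer_spec: Optional[Any]

def triage(tasks: list[Task]) -> tuple[list[Task], list[Task], list[Task]]:
    ready: list[Task] = []
    failed: list[Task] = []
    deferred: list[Task] = []
    for task in tasks:
        if task.ref_spec_failed:
            failed.append(task)    # fail the job now; still count it completed
        elif task.refer_spec is None:
            deferred.append(task)  # re-enqueue until the ref spec lands
        else:
            ready.append(task)
    return ready, failed, deferred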

View File

@@ -10,6 +10,13 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare
class EnginePrepareStageMixin:
def _prepare_waiting_total(self) -> int:
return (
int(self.prepare_queue_owner.waiting_count())
+ int(self.prepare_text_queue_owner.waiting_count())
+ int(self.prepare_ref_spec_queue_owner.waiting_count())
)
async def _wait_prepare_queue_admission(self) -> float:
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
if soft_max <= 0:
@@ -19,7 +26,7 @@ class EnginePrepareStageMixin:
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
)
wait_start = time.perf_counter()
while self.prepare_queue_owner.waiting_count() >= soft_max:
while self._prepare_waiting_total() >= soft_max:
await asyncio.sleep(poll_s)
return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
@@ -53,44 +60,247 @@ class EnginePrepareStageMixin:
done_future=done_future,
engine_request_id=engine_request_id or spec.request_id,
enqueue_time=time.perf_counter(),
phase="audio",
audio_enqueue_time=time.perf_counter(),
admission_wait_ms=float(prepare_queue_admission_wait_ms),
)
self.prepare_queue_owner.enqueue(task)
self.notify_arbiter()
return await done_future
def run_engine_prepare_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def _should_chain_prepare_text_after_audio(self) -> bool:
if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}:
return False
if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items():
return False
decode_runtime_state = self.snapshot_engine_decode_runtime_state()
if bool(decode_runtime_state.get("has_work", False)):
return False
return True
def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None:
if task.state_result is None or task.ref_spec_result is None:
return
self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms),
"ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)),
"ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)),
},
)
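
Editor's note: because the text phase and the ref_spec phase can complete in either order, _maybe_apply_ref_spec_to_state acts as a two-sided join: the text phase fills state_result, the ref_spec phase fills ref_spec_result, and whichever side lands second triggers the apply. A tiny sketch of that join (JoinTask and the applied flag are hypothetical stand-ins):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class JoinTask:
    state_result: Optional[Any] = None
    ref_spec_result: Optional[Any] = None
    applied: bool = False

def maybe_apply(task: JoinTask) -> None:
    if task.state_result is None or task.ref_spec_result is None:
        return  # the other half has not arrived yet
    task.applied = True  # stands in for apply_ref_spec_result_to_state(...)

t = JoinTask()
t.state_result = object()
maybe_apply(t)              # ref_spec still missing: no-op
t.ref_spec_result = object()
maybe_apply(t)              # both halves present: apply
assert t.applied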
def _mark_ref_spec_async_failed(
self,
task: EngineGpuPrepareTask,
error: Exception,
*,
queue_wait_ms: float,
) -> None:
task.error = str(error)
task.cancelled = True
if task.state_result is not None:
task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"ref_spec_async_failed": 1.0,
"engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms),
},
)
self.fail_request_state(task.engine_request_id or task.request_id, str(error))
self.fail_engine_jobs([task.request_id], str(error))
self.notify_arbiter()
def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
batch_results = asyncio.run(
self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks])
)
for task in tasks:
task.audio_start_time = float(now)
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks]))
completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
task.audio_end_time = time.perf_counter()
if isinstance(result, Exception):
task.error = str(result)
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result)
completed_count += 1
continue
state, prepare_exec_started_at, prepare_exec_finished_at = result
state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms)
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
task.audio_queue_wait_ms = float(queue_wait_ms)
task.phase_one = result
task.phase = "text"
task.enqueue_time = time.perf_counter()
task.text_enqueue_time = float(task.enqueue_time)
task.ref_spec_enqueue_time = float(task.enqueue_time)
self.prepare_text_queue_owner.enqueue(task)
self.prepare_ref_spec_queue_owner.enqueue(task)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(queue_wait_ms),
"engine_prepare_audio_batch_size": float(len(tasks)),
"engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
"engine_prepare_audio_start_ts": float(task.audio_start_time),
"engine_prepare_audio_end_ts": float(task.audio_end_time),
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
},
)
completed_count += 1
self.prepare_queue_owner.mark_completed(completed_count)
if completed_count > 0 and self._should_chain_prepare_text_after_audio():
self._run_engine_prepare_text_once(min(batch_max_items, completed_count))
return True
if completed_count > 0:
self.notify_arbiter()
return True
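
Editor's note: once a task's audio phase completes, it is placed on both the text and ref_spec queues at the same timestamp, so the two later phases proceed independently and join again via _maybe_apply_ref_spec_to_state. A simplified fan-out sketch; the deques and the reduced Task are stand-ins for the queue owners and EngineGpuPrepareTask:

import time
from collections import deque
from dataclasses import dataclass, field

@dataclass
class Task:
    request_id: str
    phase: str = "audio"
    phase_one: dict = field(default_factory=dict)
    text_enqueue_time: float = 0.0
    ref_spec_enqueue_time: float = 0.0

text_queue: deque = deque()
ref_spec_queue: deque = deque()

def finish_audio(task: Task, result: dict) -> None:
    task.phase_one = result
    task.phase = "text"
    now = time.perf_counter()
    task.text_enqueue_time = now      # both phases share one enqueue timestamp
    task.ref_spec_enqueue_time = now
    text_queue.append(task)           # text phase consumes phase_one
    ref_spec_queue.append(task)       # ref_spec phase consumes phase_one too

finish_audio(Task("r1"), {"phase_wall_ms": 3.2})
assert text_queue[0] is ref_spec_queue[0]  # same task object on both queues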
def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
for task in tasks:
task.text_start_time = float(now)
items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None]
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items))
completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
task.text_end_time = time.perf_counter()
if isinstance(result, Exception):
task.error = str(result)
task.cancelled = True
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result)
completed_count += 1
continue
task.text_queue_wait_ms = float(queue_wait_ms)
state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases(
task.cpu_stage,
task.phase_one or {},
result,
extra_profile={
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
"engine_prepare_audio_batch_size": float(len(tasks)),
"engine_prepare_text_batch_size": float(len(tasks)),
"engine_prepare_audio_phase_mode": 2.0,
"engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)),
"engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
"engine_prepare_text_phase_batch_size": float(len(tasks)),
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
"engine_prepare_audio_start_ts": float(task.audio_start_time),
"engine_prepare_audio_end_ts": float(task.audio_end_time),
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
"engine_prepare_text_start_ts": float(task.text_start_time),
"engine_prepare_text_end_ts": float(task.text_end_time),
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
},
)
task.state_result = state
self._maybe_apply_ref_spec_to_state(task)
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
"engine_gpu_prepare_batch_size": float(len(tasks)),
},
)
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
completed_count += 1
self.prepare_queue_owner.mark_completed(completed_count)
self.prepare_text_queue_owner.mark_completed(completed_count)
return True
def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
runnable_tasks: list[EngineGpuPrepareTask] = []
queue_wait_ms_list: list[float] = []
completed_count = 0
for task in tasks:
if task.cancelled or task.phase_one is None:
completed_count += 1
continue
task.ref_spec_start_time = float(now)
runnable_tasks.append(task)
queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0))
if not runnable_tasks:
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
return True
batch_results = asyncio.run(
self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks])
)
for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results):
task.ref_spec_end_time = time.perf_counter()
task.ref_spec_queue_wait_ms = float(queue_wait_ms)
if isinstance(result, Exception):
self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms))
completed_count += 1
continue
task.ref_spec_result = result
self._maybe_apply_ref_spec_to_state(task)
if task.state_result is not None:
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time)
task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time)
task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time)
completed_count += 1
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
return True
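
Editor's note: one accounting subtlety above: tasks skipped because they were cancelled or never produced phase_one still increment completed_count, so mark_completed drains the queue's outstanding counter even when no GPU work runs for them. A compact sketch of that rule:

from collections import deque

def drain_runnable(queue: deque, is_runnable) -> tuple[list, int]:
    """Pop everything; skipped items still count as completed for accounting."""
    runnable, completed = [], 0
    while queue:
        item = queue.popleft()
        if is_runnable(item):
            runnable.append(item)
        else:
            completed += 1  # cancelled / not ready: completed without running
    return runnable, completed

run, done = drain_runnable(deque(["ok", None, "ok"]), lambda x: x is not None)
assert run == ["ok", "ok"] and done == 1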
def run_engine_prepare_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1))
audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
if self.prepare_text_queue_owner.has_items() and (
not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms
):
return self._run_engine_prepare_text_once(batch_max_items)
if self.prepare_queue_owner.has_items():
return self._run_engine_prepare_audio_once(batch_max_items)
if self.prepare_ref_spec_queue_owner.has_items():
return self._run_engine_prepare_ref_spec_once(batch_max_items)
return False
def run_engine_prepare_audio_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def run_engine_prepare_text_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def run_engine_prepare_ref_spec_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
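
Editor's note: run_engine_prepare_once arbitrates oldest-first between the text and audio queues (peek_oldest_age_ms compares the head entries' enqueue times) before falling through to ref_spec. A standalone sketch of the same policy over deques of enqueue timestamps:

# Prefer text when its head entry has waited at least as long as the audio
# head, else run audio, then fall through to ref_spec.
from __future__ import annotations
import time
from collections import deque

def pick_stage(audio: deque, text: deque, ref_spec: deque) -> str | None:
    now = time.perf_counter()
    audio_age = (now - audio[0]) * 1000.0 if audio else 0.0
    text_age = (now - text[0]) * 1000.0 if text else 0.0
    if text and (not audio or text_age >= audio_age):
        return "text"
    if audio:
        return "audio"
    if ref_spec:
        return "ref_spec"
    return None

t0 = time.perf_counter()
assert pick_stage(deque([t0]), deque(), deque()) == "audio"
assert pick_stage(deque(), deque([t0]), deque()) == "text"
assert pick_stage(deque(), deque(), deque([t0])) == "ref_spec"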

View File

@@ -151,7 +151,9 @@ class WorkerFinalizeExecutor:
@staticmethod
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
refer_specs = [job.state.refer_spec]
refer_specs = []
if job.state.refer_spec is not None:
refer_specs.append(job.state.refer_spec)
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
return refer_specs

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import os
import time
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
@@ -81,11 +81,60 @@ class WorkerPrepareExecutor:
cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]:
try:
return list(
await asyncio.gather(
*[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages],
return_exceptions=True,
)
)
return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages)
finally:
self._notify_state_change()
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[Dict[str, Any] | Exception]:
try:
return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages)
finally:
self._notify_state_change()
async def prepare_gpu_text_phases_async(
self,
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> List[Dict[str, Any] | Exception]:
try:
return await self.coordinator.prepare_gpu_text_phases_async(items)
finally:
self._notify_state_change()
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
try:
return self.coordinator.build_gpu_prepare_result_from_phases(
cpu_stage,
phase_one,
phase_two,
extra_profile=extra_profile,
)
finally:
self._notify_state_change()
async def prepare_ref_spec_stages_async(
self,
phase_ones: List[Dict[str, Any]],
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
try:
return await self.coordinator.prepare_ref_spec_stages_async(phase_ones)
finally:
self._notify_state_change()
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
try:
self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result)
finally:
self._notify_state_change()

View File

@@ -267,3 +267,42 @@ class WorkerSubmitLifecycleMixin:
cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]:
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[Dict[str, Any] | Exception]:
return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages)
async def prepare_gpu_text_phases_async(
self,
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> List[Dict[str, Any] | Exception]:
return await self.prepare_executor.prepare_gpu_text_phases_async(items)
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
return self.prepare_executor.build_gpu_prepare_result_from_phases(
cpu_stage,
phase_one,
phase_two,
extra_profile=extra_profile,
)
async def prepare_ref_spec_stages_async(
self,
phase_ones: List[Dict[str, Any]],
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones)
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result)

View File

@@ -244,6 +244,16 @@ class G2PWRuntimeWrapper:
)
self.batch_worker.start()
def _sync_runtime_env_overrides(self) -> None:
os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0"
os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0"
os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0"
os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit))
os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit))
os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0"
os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0"
os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32")
def _destroy_handle(self) -> None:
if self.handle:
self.lib.g2pw_runtime_destroy(self.handle)
@@ -268,6 +278,7 @@
return "" if not message else message.decode("utf-8", errors="replace")
def _create_handle(self, batch_size: int, seq_len: int) -> None:
self._sync_runtime_env_overrides()
new_handle = self.lib.g2pw_runtime_create(
str(self.manifest_path).encode("utf-8"),
str(self.weights_path).encode("utf-8"),
@@ -518,6 +529,10 @@
return {
"shard_index": int(self.shard_index),
"enabled": bool(self.batch_enabled),
"enable_cuda_graph": bool(self.enable_cuda_graph),
"enable_profiling": bool(self.enable_profiling),
"full_graph_cache_limit": int(self.full_graph_cache_limit),
"tail_graph_cache_limit": int(self.tail_graph_cache_limit),
"window_ms": float(self.batch_window_s * 1000.0),
"max_requests": int(self.batch_max_requests),
"max_rows": int(self.batch_max_rows),