Enhance TTS processing with new reference specification handling and profiling metrics

Refactor the PrepareCoordinator and related components to improve the handling of reference specifications in the TTS system. Introduce new methods for building and extracting reference prompts and specifications, along with detailed profiling metrics for performance monitoring. Update the PrepareRefSemanticBatchWorker to include additional timing metrics and caching mechanisms for resampling. These changes enhance the efficiency and maintainability of the TTS framework, particularly in managing audio processing and reference data.
This commit is contained in:
baicai-1145 2026-03-13 16:45:38 +08:00
parent c94de2f2cb
commit 8a444c10b7
18 changed files with 1006 additions and 174 deletions

View File

@ -996,21 +996,39 @@ class TTS:
return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr)
return spec, audio, raw_audio, raw_sr
def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
profile = {
"ref_spec_to_device_ms": 0.0,
"ref_spec_main_resample_ms": 0.0,
"ref_spec_norm_ms": 0.0,
"ref_spec_spectrogram_ms": 0.0,
"ref_spec_post_resample_ms": 0.0,
}
to_device_start = time.perf_counter()
raw_audio_device = raw_audio.to(self.configs.device).float() raw_audio_device = raw_audio.to(self.configs.device).float()
profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0
if raw_sr != self.configs.sampling_rate: if raw_sr != self.configs.sampling_rate:
resample_start = time.perf_counter()
audio = raw_audio_device audio = raw_audio_device
if audio.shape[0] == 2: if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0) audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0
else: else:
audio = raw_audio_device audio = raw_audio_device
if audio.shape[0] == 2: if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0) audio = audio.mean(0).unsqueeze(0)
norm_start = time.perf_counter()
maxx = audio.abs().max() maxx = audio.abs().max()
if maxx > 1: if maxx > 1:
audio /= min(2, maxx) audio /= min(2, maxx)
profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0
spec_start = time.perf_counter()
spec = spectrogram_torch( spec = spectrogram_torch(
audio, audio,
self.configs.filter_length, self.configs.filter_length,
@ -1019,15 +1037,18 @@ class TTS:
self.configs.win_length, self.configs.win_length,
center=False, center=False,
) )
profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0
if self.configs.is_half: if self.configs.is_half:
spec = spec.half() spec = spec.half()
if self.is_v2pro == True: if self.is_v2pro == True:
post_resample_start = time.perf_counter()
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0
if self.configs.is_half: if self.configs.is_half:
audio = audio.half() audio = audio.half()
else: else:
audio = None audio = None
return spec, audio, raw_audio, raw_sr return spec, audio, raw_audio, raw_sr, profile
def extract_ref_spec(self, ref_audio_path: str): def extract_ref_spec(self, ref_audio_path: str):
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)

View File

@ -228,8 +228,74 @@ class PrepareCoordinator:
def _load_ref_audio_raw(self, ref_audio_path: str): def _load_ref_audio_raw(self, ref_audio_path: str):
return self.tts._load_ref_audio_raw(ref_audio_path) return self.tts._load_ref_audio_raw(ref_audio_path)
def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int):
load_profile = {"audio_load_ms": 0.0}
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
return {
"prompt_semantic": prompt_semantic,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
**load_profile,
"audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)),
"audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)),
"audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": float(
worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
+ worker_profile.get("prompt_semantic_forward_ms", 0.0)
+ worker_profile.get("prompt_semantic_scatter_ms", 0.0)
),
**{key: float(value) for key, value in worker_profile.items()},
"ref_spec_wait_ms": 0.0,
"ref_spec_ms": 0.0,
"bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0))
+ float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
+ float(worker_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)),
},
}
wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats:
prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
return {
"prompt_semantic": prompt_semantic,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": 0.0,
"audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
"audio_stage_slots": float(stage_stats.get("slots", 0.0)),
"audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)),
"prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)),
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_pack_ms": 0.0,
"prompt_semantic_h2d_ms": 0.0,
"prompt_semantic_ssl_forward_ms": 0.0,
"prompt_semantic_hidden_length_ms": 0.0,
"prompt_semantic_extract_latent_ms": 0.0,
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)),
"prompt_semantic_batch_size": 1.0,
"prompt_semantic_batch_samples": 0.0,
"ref_spec_wait_ms": 0.0,
"ref_spec_ms": 0.0,
"bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)),
},
}
def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int):
return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr)
return (spec, audio), profile
@staticmethod @staticmethod
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
@ -523,7 +589,7 @@ class PrepareCoordinator:
finally: finally:
self.text_feature_gate.release() self.text_feature_gate.release()
async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult:
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
submit_at = time.perf_counter() submit_at = time.perf_counter()
started_at = float(submit_at) started_at = float(submit_at)
@ -538,19 +604,7 @@ class PrepareCoordinator:
prompt_semantic_task = asyncio.create_task( prompt_semantic_task = asyncio.create_task(
self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
) )
await self.ref_spec_gate.acquire() prompt_semantic, prompt_semantic_profile = await prompt_semantic_task
try:
ref_spec_task = asyncio.create_task(
self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
)
(prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather(
prompt_semantic_task,
ref_spec_task,
)
finally:
self.ref_spec_gate.release()
refer_spec = ref_spec_profiled.result
limiter_snapshot = ( limiter_snapshot = (
self.tts.prepare_ref_audio_stage_limiter.snapshot() self.tts.prepare_ref_audio_stage_limiter.snapshot()
if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None
@ -561,21 +615,15 @@ class PrepareCoordinator:
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
) )
audio_stage_wait_ms = (
float(load_profiled.queue_ms)
+ float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0))
+ float(ref_spec_profiled.queue_ms)
)
finished_at = time.perf_counter() finished_at = time.perf_counter()
result = { result = {
"prompt_semantic": prompt_semantic, "prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio, "raw_audio": raw_audio,
"raw_sr": raw_sr, "raw_sr": raw_sr,
"profile": { "profile": {
"audio_load_queue_ms": float(load_profiled.queue_ms), "audio_load_queue_ms": float(load_profiled.queue_ms),
"audio_load_ms": float(load_profiled.run_ms), "audio_load_ms": float(load_profiled.run_ms),
"audio_stage_wait_ms": float(audio_stage_wait_ms), "audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"audio_stage_slots": float( "audio_stage_slots": float(
max( max(
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
@ -593,6 +641,17 @@ class PrepareCoordinator:
"prompt_semantic_cpu_prepare_ms": float( "prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
), ),
"prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)),
"prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)),
"prompt_semantic_ssl_forward_ms": float(
prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0)
),
"prompt_semantic_hidden_length_ms": float(
prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0)
),
"prompt_semantic_extract_latent_ms": float(
prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0)
),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
@ -603,14 +662,10 @@ class PrepareCoordinator:
"prompt_semantic_batch_samples": float( "prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
), ),
"ref_spec_wait_ms": float(ref_spec_profiled.queue_ms),
"ref_spec_ms": float(ref_spec_profiled.run_ms),
"bundle_total_ms": float( "bundle_total_ms": float(
load_profiled.queue_ms load_profiled.queue_ms
+ load_profiled.run_ms + load_profiled.run_ms
+ prompt_semantic_ms + prompt_semantic_ms
+ ref_spec_profiled.queue_ms
+ ref_spec_profiled.run_ms
), ),
}, },
} }
@ -623,21 +678,26 @@ class PrepareCoordinator:
await self.ref_audio_gate.acquire() await self.ref_audio_gate.acquire()
try: try:
if hasattr(self.tts, "extract_ref_audio_bundle_async"): load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path)
submit_at = time.perf_counter() raw_audio, raw_sr = load_profiled.result
started_at = time.perf_counter() submit_at = time.perf_counter()
result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) started_at = time.perf_counter()
finished_at = time.perf_counter() result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr)
return ProfiledResult( result.setdefault("profile", {})
result=result, result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms)
submit_at=float(submit_at), result["profile"]["audio_load_ms"] = float(load_profiled.run_ms)
started_at=float(started_at), finished_at = time.perf_counter()
finished_at=float(finished_at), return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at))
)
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
finally: finally:
self.ref_audio_gate.release() self.ref_audio_gate.release()
async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult:
await self.ref_spec_gate.acquire()
try:
return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
finally:
self.ref_spec_gate.release()
def _release_split_stage_slot(self) -> None: def _release_split_stage_slot(self) -> None:
self._mark_leave() self._mark_leave()
self._inflight_gate.release() self._inflight_gate.release()
@ -682,101 +742,318 @@ class PrepareCoordinator:
cpu_stage: PreparedCpuStage, cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]: ) -> tuple[T2SRequestState, float, float]:
try: try:
g2pw_pair_start = time.perf_counter() phase_one = await self._prepare_gpu_phase_one(cpu_stage)
g2pw_pair_task = asyncio.create_task( phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one)
self._run_g2pw_pair_stage( return self._build_gpu_prepare_result(
cpu_stage.prompt_cpu_profiled.result, cpu_stage,
cpu_stage.target_cpu_profiled.result, phase_one,
) phase_two,
) extra_profile={
ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) "engine_prepare_audio_phase_mode": 0.0,
(prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather( "engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]),
g2pw_pair_task, "engine_prepare_audio_phase_batch_size": 1.0,
ref_audio_task, "engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]),
) "engine_prepare_text_phase_batch_size": 1.0,
g2pw_pair_end = time.perf_counter()
text_pair_start = time.perf_counter()
text_feature_pair_task = asyncio.create_task(
self._run_text_feature_pair_stage(
prompt_g2pw_profiled.result,
target_g2pw_profiled.result,
cpu_stage.prompt_cpu_profiled.run_ms,
cpu_stage.target_cpu_profiled.run_ms,
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
target_base_profile=dict(target_g2pw_profiled.profile or {}),
)
)
prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task
text_pair_end = time.perf_counter()
state = build_request_state_from_parts(
tts=self.tts,
spec=cpu_stage.spec,
prompt_text=cpu_stage.prompt_text,
text=cpu_stage.text,
prompt_result=prompt_feature_profiled.result,
target_result=target_feature_profiled.result,
ref_audio_bundle=ref_audio_profiled.result,
prepare_start=cpu_stage.prepare_start,
prepare_sync_start=cpu_stage.prepare_start,
profile_overrides={
"executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
"text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0),
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": 0.0,
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
"prompt_text_cpu_admission_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"prompt_text_cpu_backpressure_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"prompt_text_cpu_capacity_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
"text_cpu_admission_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"text_cpu_backpressure_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"text_cpu_capacity_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"text_feature_queue_ms": target_feature_profiled.queue_ms,
"text_feature_run_ms": target_feature_profiled.run_ms,
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
"worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
"worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
}, },
) )
prepare_exec_finished_at = time.perf_counter() finally:
state.prepare_profile["executor_run_wall_ms"] = max( self._release_split_stage_slot()
0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0
async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]:
phase_start = time.perf_counter()
g2pw_pair_task = asyncio.create_task(
self._run_g2pw_pair_stage(
cpu_stage.prompt_cpu_profiled.result,
cpu_stage.target_cpu_profiled.result,
) )
return state, cpu_stage.prepare_start, prepare_exec_finished_at )
ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path)))
prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task
g2pw_pair_end = time.perf_counter()
ref_audio_profiled = await ref_audio_task
phase_end = time.perf_counter()
return {
"prompt_g2pw_profiled": prompt_g2pw_profiled,
"target_g2pw_profiled": target_g2pw_profiled,
"ref_audio_profiled": ref_audio_profiled,
"ref_spec_result": None,
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0),
"phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0),
}
async def _prepare_gpu_phase_two(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
) -> Dict[str, Any]:
phase_start = time.perf_counter()
prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"]
target_g2pw_profiled = phase_one["target_g2pw_profiled"]
prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage(
prompt_g2pw_profiled.result,
target_g2pw_profiled.result,
cpu_stage.prompt_cpu_profiled.run_ms,
cpu_stage.target_cpu_profiled.run_ms,
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
target_base_profile=dict(target_g2pw_profiled.profile or {}),
)
phase_end = time.perf_counter()
return {
"prompt_feature_profiled": prompt_feature_profiled,
"target_feature_profiled": target_feature_profiled,
"phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0),
}
def _build_gpu_prepare_result(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"]
target_g2pw_profiled = phase_one["target_g2pw_profiled"]
ref_audio_profiled = phase_one["ref_audio_profiled"]
ref_spec_result = phase_one.get("ref_spec_result")
prompt_feature_profiled = phase_two["prompt_feature_profiled"]
target_feature_profiled = phase_two["target_feature_profiled"]
profile_overrides = {
"executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
"prepare_submit_ts": float(cpu_stage.prepare_submit_at),
"prepare_cpu_start_ts": float(cpu_stage.prepare_start),
"prepare_cpu_done_ts": float(
max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at)
),
"prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at),
"prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at),
"text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at),
"text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at),
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
"text_feature_pair_ms": float(phase_two["phase_wall_ms"]),
"g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]),
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": 0.0,
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
"prompt_text_cpu_admission_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"prompt_text_cpu_backpressure_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"prompt_text_cpu_capacity_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
"text_cpu_admission_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"text_cpu_backpressure_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"text_cpu_capacity_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"text_feature_queue_ms": target_feature_profiled.queue_ms,
"text_feature_run_ms": target_feature_profiled.run_ms,
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
"worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
"worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
}
if extra_profile:
profile_overrides.update({key: float(value) for key, value in extra_profile.items()})
ref_audio_bundle = dict(ref_audio_profiled.result)
ref_audio_profile = dict(ref_audio_bundle.get("profile", {}))
if ref_spec_result is not None:
refer_spec, ref_spec_profiled = ref_spec_result
ref_audio_bundle["refer_spec"] = refer_spec
ref_audio_profile.update(
{
"ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)),
"ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)),
}
)
else:
ref_audio_bundle["refer_spec"] = None
ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0)
ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0)
ref_audio_bundle["profile"] = ref_audio_profile
state = build_request_state_from_parts(
tts=self.tts,
spec=cpu_stage.spec,
prompt_text=cpu_stage.prompt_text,
text=cpu_stage.text,
prompt_result=prompt_feature_profiled.result,
target_result=target_feature_profiled.result,
ref_audio_bundle=ref_audio_bundle,
prepare_start=cpu_stage.prepare_start,
prepare_sync_start=cpu_stage.prepare_start,
profile_overrides=profile_overrides,
)
prepare_exec_finished_at = time.perf_counter()
state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0)
return state, cpu_stage.prepare_start, prepare_exec_finished_at
async def prepare_ref_spec_stages_async(
self,
phase_ones: list[Dict[str, Any]],
) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
async def _one(phase_one: Dict[str, Any]):
ref_audio_profiled = phase_one["ref_audio_profiled"]
raw_audio = ref_audio_profiled.result["raw_audio"]
raw_sr = int(ref_audio_profiled.result["raw_sr"])
profiled = await self._run_ref_spec_stage(raw_audio, raw_sr)
refer_spec, profile = profiled.result
merged_profile = dict(profile)
merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms)
merged_profile["ref_spec_ms"] = float(profiled.run_ms)
return refer_spec, merged_profile
if not phase_ones:
return []
return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True))
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
refer_spec, profile = ref_spec_result
state.refer_spec = refer_spec
state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0))
state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0))
state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0))
state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0))
state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0))
state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0))
state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0))
async def prepare_gpu_stages_profiled_async(
self,
cpu_stages: list[PreparedCpuStage],
) -> list[tuple[T2SRequestState, float, float] | Exception]:
if not cpu_stages:
return []
if len(cpu_stages) == 1:
single_stage = cpu_stages[0]
try:
return [await self.prepare_gpu_stage_profiled_async(single_stage)]
except Exception as exc: # noqa: PERF203
return [exc]
phase_one_started_at = time.perf_counter()
phase_one_results = await asyncio.gather(
*[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages],
return_exceptions=True,
)
phase_one_finished_at = time.perf_counter()
phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0)
outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages)
pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = []
for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)):
if isinstance(phase_one, Exception):
outputs[index] = phase_one
self._release_split_stage_slot()
continue
pending_phase_two.append((index, cpu_stage, phase_one))
phase_two_started_at = time.perf_counter()
phase_two_results = await asyncio.gather(
*[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two],
return_exceptions=True,
)
phase_two_finished_at = time.perf_counter()
phase_two_wall_ms = max(0.0, (phase_two_finished_at - phase_two_started_at) * 1000.0)
for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results):
try:
if isinstance(phase_two, Exception):
outputs[index] = phase_two
continue
outputs[index] = self._build_gpu_prepare_result(
cpu_stage,
phase_one,
phase_two,
extra_profile={
"engine_prepare_audio_phase_mode": 1.0,
"engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms),
"engine_prepare_audio_phase_batch_size": float(len(cpu_stages)),
"engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms),
"engine_prepare_text_phase_batch_size": float(len(pending_phase_two)),
},
)
except Exception as exc: # noqa: PERF203
outputs[index] = exc
finally:
self._release_split_stage_slot()
return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs]
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: list[PreparedCpuStage],
) -> list[Dict[str, Any] | Exception]:
if not cpu_stages:
return []
return list(
await asyncio.gather(
*[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages],
return_exceptions=True,
)
)
async def prepare_gpu_text_phases_async(
self,
items: list[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> list[Dict[str, Any] | Exception]:
if not items:
return []
return list(
await asyncio.gather(
*[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items],
return_exceptions=True,
)
)
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
try:
return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile)
finally: finally:
self._release_split_stage_slot() self._release_split_stage_slot()

View File

@ -15,6 +15,7 @@ REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000 REF_AUDIO_MAX_SAMPLES_16K = 160000
_RESAMPLE_CACHE_LOCK = threading.Lock() _RESAMPLE_CACHE_LOCK = threading.Lock()
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {} _RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {}
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample: def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
@ -28,6 +29,16 @@ def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.tran
return transform return transform
def _get_resample_stream(device: str) -> torch.cuda.Stream:
device_key = str(device)
with _RESAMPLE_CACHE_LOCK:
stream = _RESAMPLE_STREAM_CACHE.get(device_key)
if stream is None:
stream = torch.cuda.Stream(device=device_key)
_RESAMPLE_STREAM_CACHE[device_key] = stream
return stream
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu" resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
if resample_device not in {"cpu", "cuda"}: if resample_device not in {"cpu", "cuda"}:
@ -37,10 +48,20 @@ def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wa
wav_mono = raw_audio wav_mono = raw_audio
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
wav_mono = wav_mono.mean(0, keepdim=True) wav_mono = wav_mono.mean(0, keepdim=True)
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) if resample_device == "cuda":
if raw_sr != 16000: stream = _get_resample_stream(resample_device)
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) with torch.cuda.stream(stream):
wav16k = wav16k.squeeze(0).contiguous() wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
stream.synchronize()
wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous()
else:
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
raise OSError("参考音频在3~10秒范围外请更换") raise OSError("参考音频在3~10秒范围外请更换")
if zero_wav_samples > 0: if zero_wav_samples > 0:
@ -256,37 +277,56 @@ class PrepareRefSemanticBatchWorker:
batch_samples = int(wav_lengths.sum().item()) batch_samples = int(wav_lengths.sum().item())
max_wav_len = int(wav_lengths.max().item()) max_wav_len = int(wav_lengths.max().item())
pack_start = time.perf_counter()
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
for batch_index, wav in enumerate(prepared_wavs): for batch_index, wav in enumerate(prepared_wavs):
wav_len = int(wav.shape[0]) wav_len = int(wav.shape[0])
input_values_cpu[batch_index, :wav_len] = wav input_values_cpu[batch_index, :wav_len] = wav
attention_mask_cpu[batch_index, :wav_len] = 1 attention_mask_cpu[batch_index, :wav_len] = 1
pack_ms = (time.perf_counter() - pack_start) * 1000.0
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
h2d_ms = 0.0
ssl_forward_ms = 0.0
hidden_length_ms = 0.0
extract_latent_ms = 0.0
if self.stage_limiter is None: if self.stage_limiter is None:
h2d_start = time.perf_counter()
input_values = input_values_cpu.to(self.device) input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device)
if self.is_half: if self.is_half:
input_values = input_values.half() input_values = input_values.half()
forward_start = time.perf_counter() h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
ssl_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
hubert_feature = outputs["last_hidden_state"].transpose(1, 2) hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_length_start = time.perf_counter()
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
latent_start = time.perf_counter()
codes = self.vits_model.extract_latent(hubert_feature) codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0 extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
else: else:
with self.stage_limiter.enter() as limiter_stats: with self.stage_limiter.enter() as limiter_stats:
h2d_start = time.perf_counter()
input_values = input_values_cpu.to(self.device) input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device)
if self.is_half: if self.is_half:
input_values = input_values.half() input_values = input_values.half()
forward_start = time.perf_counter() h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
ssl_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
hubert_feature = outputs["last_hidden_state"].transpose(1, 2) hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_length_start = time.perf_counter()
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
latent_start = time.perf_counter()
codes = self.vits_model.extract_latent(hubert_feature) codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0 extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms)
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
scatter_start = time.perf_counter() scatter_start = time.perf_counter()
@ -308,6 +348,11 @@ class PrepareRefSemanticBatchWorker:
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
), ),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_pack_ms": float(pack_ms),
"prompt_semantic_h2d_ms": float(h2d_ms),
"prompt_semantic_ssl_forward_ms": float(ssl_forward_ms),
"prompt_semantic_hidden_length_ms": float(hidden_length_ms),
"prompt_semantic_extract_latent_ms": float(extract_latent_ms),
"prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0, "prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_calls": 1.0, "prompt_semantic_calls": 1.0,

View File

@ -55,7 +55,7 @@ class T2SRequestState:
all_phones: torch.LongTensor all_phones: torch.LongTensor
all_bert_features: torch.Tensor all_bert_features: torch.Tensor
prompt_semantic: torch.LongTensor prompt_semantic: torch.LongTensor
refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
raw_audio: torch.Tensor raw_audio: torch.Tensor
raw_sr: int raw_sr: int
@ -188,7 +188,11 @@ def build_request_state_from_parts(
ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0))
bundle_profile = ref_audio_bundle.get("profile", {}) bundle_profile = ref_audio_bundle.get("profile", {})
prompt_semantic = ref_audio_bundle["prompt_semantic"].long() prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
spec_audio, audio_16k = ref_audio_bundle["refer_spec"] refer_spec_value = ref_audio_bundle.get("refer_spec")
if refer_spec_value in [None, ()]:
spec_audio, audio_16k = None, None
else:
spec_audio, audio_16k = refer_spec_value
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [] aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []): for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []):
if aux_ref_audio_path in [None, ""]: if aux_ref_audio_path in [None, ""]:
@ -323,6 +327,11 @@ def build_request_state_from_parts(
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
), ),
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)),
"prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)),
"prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)),
"prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)),
"prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)),
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)),
@ -331,6 +340,11 @@ def build_request_state_from_parts(
"prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)),
"ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": ref_spec_ms, "ref_spec_ms": ref_spec_ms,
"ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)),
"ref_audio_bundle_ms": ref_audio_bundle_ms, "ref_audio_bundle_ms": ref_audio_bundle_ms,
"tensorize_ms": tensorize_ms, "tensorize_ms": tensorize_ms,
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
@ -352,7 +366,7 @@ def build_request_state_from_parts(
all_phones=all_phones, all_phones=all_phones,
all_bert_features=all_bert_features, all_bert_features=all_bert_features,
prompt_semantic=prompt_semantic, prompt_semantic=prompt_semantic,
refer_spec=(spec_audio, audio_16k), refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)),
aux_refer_specs=aux_refer_specs, aux_refer_specs=aux_refer_specs,
raw_audio=raw_audio, raw_audio=raw_audio,
raw_sr=raw_sr, raw_sr=raw_sr,

View File

@ -21,6 +21,14 @@ class EngineRegistryBridgeFacade:
def engine_prepare_queue_owner(self): def engine_prepare_queue_owner(self):
return self.owner.engine_prepare_queue_owner return self.owner.engine_prepare_queue_owner
@property
def engine_prepare_text_queue_owner(self):
return self.owner.engine_prepare_text_queue_owner
@property
def engine_prepare_ref_spec_queue_owner(self):
return self.owner.engine_prepare_ref_spec_queue_owner
@property @property
def engine_finalize_queue_owner(self): def engine_finalize_queue_owner(self):
return self.owner.engine_finalize_queue_owner return self.owner.engine_finalize_queue_owner
@ -82,7 +90,33 @@ class EngineRegistryBridgeFacade:
return self.request_registry.snapshot() return self.request_registry.snapshot()
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16)
ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16)
return {
"waiting_count": int(audio_snapshot.get("waiting_count", 0))
+ int(text_snapshot.get("waiting_count", 0))
+ int(ref_spec_snapshot.get("waiting_count", 0)),
"audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)),
"text_waiting_count": int(text_snapshot.get("waiting_count", 0)),
"ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)),
"audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])),
"text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])),
"ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])),
"peak_waiting": int(
max(
int(audio_snapshot.get("peak_waiting", 0)),
int(text_snapshot.get("peak_waiting", 0)),
int(ref_spec_snapshot.get("peak_waiting", 0)),
)
),
"total_submitted": int(audio_snapshot.get("total_submitted", 0)),
"total_completed": int(audio_snapshot.get("total_completed", 0)),
"text_total_submitted": int(text_snapshot.get("total_submitted", 0)),
"text_total_completed": int(text_snapshot.get("total_completed", 0)),
"ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)),
"ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)),
}
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
@ -107,6 +141,8 @@ class EngineRegistryBridgeFacade:
def _is_engine_drained(self) -> bool: def _is_engine_drained(self) -> bool:
prepare_empty = self.engine_prepare_queue_owner.is_drained() prepare_empty = self.engine_prepare_queue_owner.is_drained()
prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained()
prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained()
dispatch_empty = self.engine_dispatch_queue_owner.is_drained() dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
finalize_empty = self.engine_finalize_queue_owner.is_drained() finalize_empty = self.engine_finalize_queue_owner.is_drained()
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
@ -114,6 +150,8 @@ class EngineRegistryBridgeFacade:
worker_state = self.scheduler_worker.snapshot() worker_state = self.scheduler_worker.snapshot()
return bool( return bool(
prepare_empty prepare_empty
and prepare_text_empty
and prepare_ref_spec_empty
and dispatch_empty and dispatch_empty
and finalize_empty and finalize_empty
and decode_pending_empty and decode_pending_empty

View File

@ -110,6 +110,8 @@ class EngineCompositionBuilder:
get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s,
) )
owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")
@ -119,6 +121,8 @@ class EngineCompositionBuilder:
tts=owner.tts, tts=owner.tts,
scheduler_worker=owner.scheduler_worker, scheduler_worker=owner.scheduler_worker,
prepare_queue_owner=owner.engine_prepare_queue_owner, prepare_queue_owner=owner.engine_prepare_queue_owner,
prepare_text_queue_owner=owner.engine_prepare_text_queue_owner,
prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner,
finalize_queue_owner=owner.engine_finalize_queue_owner, finalize_queue_owner=owner.engine_finalize_queue_owner,
dispatch_queue_owner=owner.engine_dispatch_queue_owner, dispatch_queue_owner=owner.engine_dispatch_queue_owner,
decode_runtime_owner=owner.engine_decode_runtime_owner, decode_runtime_owner=owner.engine_decode_runtime_owner,

View File

@ -129,7 +129,7 @@ class EnginePolicyArbiterController:
self.state.total_ticks += 1 self.state.total_ticks += 1
if stage == "idle": if stage == "idle":
self.state.total_idle_ticks += 1 self.state.total_idle_ticks += 1
elif stage == "prepare": elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}:
self.state.total_prepare_dispatches += 1 self.state.total_prepare_dispatches += 1
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
elif stage == "finalize": elif stage == "finalize":
@ -282,7 +282,11 @@ class EnginePolicyArbiterController:
request_registry = self.snapshot_request_registry() request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state() worker_state = self.get_worker_state()
policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) prepare_state = self.snapshot_prepare_state()
prepare_waiting = int(prepare_state.get("waiting_count", 0))
prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0))
prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0))
prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0))
finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
decode_runtime_state = self.snapshot_decode_runtime_state() decode_runtime_state = self.snapshot_decode_runtime_state()
@ -291,6 +295,9 @@ class EnginePolicyArbiterController:
worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
prepare_age_ms = float(self.peek_queue_age_ms("prepare")) prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio"))
prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text"))
prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec"))
finalize_age_ms = float(self.peek_queue_age_ms("finalize")) finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
@ -316,14 +323,31 @@ class EnginePolicyArbiterController:
and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
): ):
return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
if (
finalize_waiting > 0
and prepare_ref_spec_waiting > 0
and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0)
):
return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state
if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
return "finalize", "finalize_aging", policy_snapshot, worker_state return "finalize", "finalize_aging", policy_snapshot, worker_state
if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0:
return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
return "prepare", "prepare_aging", policy_snapshot, worker_state if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6):
return "prepare_text", "prepare_aging", policy_snapshot, worker_state
if (
prepare_ref_spec_waiting > 0
and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6)
):
return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state
return "prepare_audio", "prepare_aging", policy_snapshot, worker_state
if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
if decode_waiting > 0 and policy_allowed: if decode_waiting > 0 and policy_allowed:
@ -331,5 +355,9 @@ class EnginePolicyArbiterController:
if finalize_waiting > 0: if finalize_waiting > 0:
return "finalize", "finalize_fallback", policy_snapshot, worker_state return "finalize", "finalize_fallback", policy_snapshot, worker_state
if prepare_waiting > 0: if prepare_waiting > 0:
return "prepare", "prepare_fallback", policy_snapshot, worker_state if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
return "prepare_text", "prepare_fallback", policy_snapshot, worker_state
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0:
return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state
return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state
return "idle", "no_pending_work", policy_snapshot, worker_state return "idle", "no_pending_work", policy_snapshot, worker_state

View File

@ -324,8 +324,24 @@ class EngineGpuPrepareTask:
done_future: asyncio.Future | None done_future: asyncio.Future | None
engine_request_id: str | None engine_request_id: str | None
enqueue_time: float enqueue_time: float
queue_wait_ms: float = 0.0 phase: str = "audio"
audio_enqueue_time: float = 0.0
audio_start_time: float = 0.0
audio_end_time: float = 0.0
text_enqueue_time: float = 0.0
text_start_time: float = 0.0
text_end_time: float = 0.0
ref_spec_enqueue_time: float = 0.0
ref_spec_start_time: float = 0.0
ref_spec_end_time: float = 0.0
audio_queue_wait_ms: float = 0.0
text_queue_wait_ms: float = 0.0
ref_spec_queue_wait_ms: float = 0.0
admission_wait_ms: float = 0.0 admission_wait_ms: float = 0.0
phase_one: Dict[str, Any] | None = None
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None
state_result: T2SRequestState | None = None
cancelled: bool = False
error: str | None = None error: str | None = None

View File

@ -14,6 +14,8 @@ class EngineStageOrchestrator:
executor: EngineStageExecutor, executor: EngineStageExecutor,
scheduler_worker: UnifiedSchedulerWorker, scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner, prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner, decode_runtime_owner: EngineDecodeRuntimeOwner,
@ -22,6 +24,8 @@ class EngineStageOrchestrator:
self.executor = executor self.executor = executor
self.scheduler_worker = scheduler_worker self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner self.prepare_queue_owner = prepare_queue_owner
self.prepare_text_queue_owner = prepare_text_queue_owner
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
self.finalize_queue_owner = finalize_queue_owner self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner self.decode_runtime_owner = decode_runtime_owner
@ -45,7 +49,17 @@ class EngineStageOrchestrator:
def peek_queue_age_ms(self, queue_name: str) -> float: def peek_queue_age_ms(self, queue_name: str) -> float:
if queue_name == "prepare": if queue_name == "prepare":
return max(
self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"),
self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"),
self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"),
)
if queue_name == "prepare_audio":
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "prepare_text":
return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "prepare_ref_spec":
return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "finalize": if queue_name == "finalize":
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
if queue_name == "decode_runtime_pending": if queue_name == "decode_runtime_pending":
@ -62,6 +76,10 @@ class EngineStageOrchestrator:
return True return True
if self.prepare_queue_owner.has_items(): if self.prepare_queue_owner.has_items():
return True return True
if self.prepare_text_queue_owner.has_items():
return True
if self.prepare_ref_spec_queue_owner.has_items():
return True
if self.finalize_queue_owner.has_items(): if self.finalize_queue_owner.has_items():
return True return True
return self.dispatch_queue_owner.has_items() return self.dispatch_queue_owner.has_items()
@ -79,6 +97,12 @@ class EngineStageOrchestrator:
executed = False executed = False
if stage == "prepare": if stage == "prepare":
executed = self.executor.run_engine_prepare_once() executed = self.executor.run_engine_prepare_once()
elif stage == "prepare_audio":
executed = self.executor.run_engine_prepare_audio_once()
elif stage == "prepare_text":
executed = self.executor.run_engine_prepare_text_once()
elif stage == "prepare_ref_spec":
executed = self.executor.run_engine_prepare_ref_spec_once()
elif stage == "finalize": elif stage == "finalize":
executed = self.executor.run_engine_finalize_once() executed = self.executor.run_engine_finalize_once()
elif stage == "decode_dispatch": elif stage == "decode_dispatch":

View File

@ -24,6 +24,8 @@ class EngineStageCoordinator:
tts: TTS, tts: TTS,
scheduler_worker: UnifiedSchedulerWorker, scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner, prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner, decode_runtime_owner: EngineDecodeRuntimeOwner,
@ -45,6 +47,8 @@ class EngineStageCoordinator:
tts=tts, tts=tts,
scheduler_worker=scheduler_worker, scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner, prepare_queue_owner=prepare_queue_owner,
prepare_text_queue_owner=prepare_text_queue_owner,
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
finalize_queue_owner=finalize_queue_owner, finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner, dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner, decode_runtime_owner=decode_runtime_owner,
@ -66,6 +70,8 @@ class EngineStageCoordinator:
executor=self.executor, executor=self.executor,
scheduler_worker=scheduler_worker, scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner, prepare_queue_owner=prepare_queue_owner,
prepare_text_queue_owner=prepare_text_queue_owner,
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
finalize_queue_owner=finalize_queue_owner, finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner, dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner, decode_runtime_owner=decode_runtime_owner,
@ -144,6 +150,15 @@ class EngineStageCoordinator:
def run_engine_prepare_once(self) -> bool: def run_engine_prepare_once(self) -> bool:
return self.executor.run_engine_prepare_once() return self.executor.run_engine_prepare_once()
def run_engine_prepare_audio_once(self) -> bool:
return self.executor.run_engine_prepare_audio_once()
def run_engine_prepare_text_once(self) -> bool:
return self.executor.run_engine_prepare_text_once()
def run_engine_prepare_ref_spec_once(self) -> bool:
return self.executor.run_engine_prepare_ref_spec_once()
def run_engine_finalize_once(self) -> bool: def run_engine_finalize_once(self) -> bool:
return self.executor.run_engine_finalize_once() return self.executor.run_engine_finalize_once()

View File

@ -24,6 +24,10 @@ class EngineDispatchStageMixin:
engine_request_id: str | None, engine_request_id: str | None,
timeout_sec: float | None, timeout_sec: float | None,
) -> EngineDispatchTask: ) -> EngineDispatchTask:
if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
error = RuntimeError("ref_spec async stage failed before dispatch")
self.fail_request_state(engine_request_id or state.request_id, str(error))
raise error
task = EngineDispatchTask( task = EngineDispatchTask(
request_id=state.request_id, request_id=state.request_id,
state=state, state=state,

View File

@ -31,6 +31,8 @@ class EngineStageExecutor(
tts: TTS, tts: TTS,
scheduler_worker: UnifiedSchedulerWorker, scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner, prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner, decode_runtime_owner: EngineDecodeRuntimeOwner,
@ -51,6 +53,8 @@ class EngineStageExecutor(
self.tts = tts self.tts = tts
self.scheduler_worker = scheduler_worker self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner self.prepare_queue_owner = prepare_queue_owner
self.prepare_text_queue_owner = prepare_text_queue_owner
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
self.finalize_queue_owner = finalize_queue_owner self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner self.decode_runtime_owner = decode_runtime_owner

View File

@ -39,10 +39,37 @@ class EngineFinalizeStageMixin:
tasks = self.take_engine_finalize_batch_nonblocking() tasks = self.take_engine_finalize_batch_nonblocking()
if not tasks: if not tasks:
return False return False
self.scheduler_worker.begin_finalize_execution(len(tasks)) ready_tasks: List[SchedulerFinalizeTask] = []
failed_tasks: List[SchedulerFinalizeTask] = []
deferred_tasks: List[SchedulerFinalizeTask] = []
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
failed_tasks.append(task)
continue
if job.state.refer_spec is None:
deferred_tasks.append(task)
self.merge_request_state_profile(
job.engine_request_id or job.request_id,
{
"engine_finalize_ref_spec_blocked": 1.0,
},
)
continue
ready_tasks.append(task)
if deferred_tasks:
self.finalize_queue_owner.enqueue_many(deferred_tasks)
if failed_tasks:
self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed")
if not ready_tasks:
self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True)
return False
self.scheduler_worker.begin_finalize_execution(len(ready_tasks))
try: try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
for task in tasks: for task in ready_tasks:
job = self.get_engine_job(task.request_id) job = self.get_engine_job(task.request_id)
if job is None: if job is None:
continue continue
@ -50,7 +77,7 @@ class EngineFinalizeStageMixin:
if not jobs_and_items: if not jobs_and_items:
return False return False
now = time.perf_counter() now = time.perf_counter()
for task in tasks: for task in ready_tasks:
job = self.get_engine_job(task.request_id) job = self.get_engine_job(task.request_id)
if job is not None: if job is not None:
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
@ -69,8 +96,8 @@ class EngineFinalizeStageMixin:
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc: except Exception as exc:
self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc))
finally: finally:
self.scheduler_worker.end_finalize_execution(len(tasks)) self.scheduler_worker.end_finalize_execution(len(ready_tasks))
self.finalize_queue_owner.mark_completed(len(tasks), notify=True) self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True)
return True return True

View File

@ -10,6 +10,13 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare
class EnginePrepareStageMixin: class EnginePrepareStageMixin:
def _prepare_waiting_total(self) -> int:
return (
int(self.prepare_queue_owner.waiting_count())
+ int(self.prepare_text_queue_owner.waiting_count())
+ int(self.prepare_ref_spec_queue_owner.waiting_count())
)
async def _wait_prepare_queue_admission(self) -> float: async def _wait_prepare_queue_admission(self) -> float:
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
if soft_max <= 0: if soft_max <= 0:
@ -19,7 +26,7 @@ class EnginePrepareStageMixin:
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
) )
wait_start = time.perf_counter() wait_start = time.perf_counter()
while self.prepare_queue_owner.waiting_count() >= soft_max: while self._prepare_waiting_total() >= soft_max:
await asyncio.sleep(poll_s) await asyncio.sleep(poll_s)
return max(0.0, (time.perf_counter() - wait_start) * 1000.0) return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
@ -53,44 +60,247 @@ class EnginePrepareStageMixin:
done_future=done_future, done_future=done_future,
engine_request_id=engine_request_id or spec.request_id, engine_request_id=engine_request_id or spec.request_id,
enqueue_time=time.perf_counter(), enqueue_time=time.perf_counter(),
phase="audio",
audio_enqueue_time=time.perf_counter(),
admission_wait_ms=float(prepare_queue_admission_wait_ms), admission_wait_ms=float(prepare_queue_admission_wait_ms),
) )
self.prepare_queue_owner.enqueue(task) self.prepare_queue_owner.enqueue(task)
self.notify_arbiter() self.notify_arbiter()
return await done_future return await done_future
def run_engine_prepare_once(self) -> bool: def _should_chain_prepare_text_after_audio(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}:
tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) return False
if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items():
return False
decode_runtime_state = self.snapshot_engine_decode_runtime_state()
if bool(decode_runtime_state.get("has_work", False)):
return False
return True
def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None:
if task.state_result is None or task.ref_spec_result is None:
return
self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms),
"ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)),
"ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)),
},
)
def _mark_ref_spec_async_failed(
self,
task: EngineGpuPrepareTask,
error: Exception,
*,
queue_wait_ms: float,
) -> None:
task.error = str(error)
task.cancelled = True
if task.state_result is not None:
task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"ref_spec_async_failed": 1.0,
"engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms),
},
)
self.fail_request_state(task.engine_request_id or task.request_id, str(error))
self.fail_engine_jobs([task.request_id], str(error))
self.notify_arbiter()
def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_queue_owner.pop_left_many(batch_max_items)
if not tasks: if not tasks:
return False return False
now = time.perf_counter() now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
batch_results = asyncio.run( for task in tasks:
self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks]) task.audio_start_time = float(now)
) batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks]))
completed_count = 0 completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
task.audio_end_time = time.perf_counter()
if isinstance(result, Exception): if isinstance(result, Exception):
task.error = str(result) task.error = str(result)
self.fail_request_state(task.engine_request_id or task.request_id, str(result)) self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result) self._notify_prepare_error(task, result)
completed_count += 1 completed_count += 1
continue continue
state, prepare_exec_started_at, prepare_exec_finished_at = result task.audio_queue_wait_ms = float(queue_wait_ms)
state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms) task.phase_one = result
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) task.phase = "text"
task.enqueue_time = time.perf_counter()
task.text_enqueue_time = float(task.enqueue_time)
task.ref_spec_enqueue_time = float(task.enqueue_time)
self.prepare_text_queue_owner.enqueue(task)
self.prepare_ref_spec_queue_owner.enqueue(task)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(queue_wait_ms),
"engine_prepare_audio_batch_size": float(len(tasks)),
"engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
"engine_prepare_audio_start_ts": float(task.audio_start_time),
"engine_prepare_audio_end_ts": float(task.audio_end_time),
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
},
)
completed_count += 1
self.prepare_queue_owner.mark_completed(completed_count)
if completed_count > 0 and self._should_chain_prepare_text_after_audio():
self._run_engine_prepare_text_once(min(batch_max_items, completed_count))
return True
if completed_count > 0:
self.notify_arbiter()
return True
def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
for task in tasks:
task.text_start_time = float(now)
items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None]
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items))
completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
task.text_end_time = time.perf_counter()
if isinstance(result, Exception):
task.error = str(result)
task.cancelled = True
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result)
completed_count += 1
continue
task.text_queue_wait_ms = float(queue_wait_ms)
state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases(
task.cpu_stage,
task.phase_one or {},
result,
extra_profile={
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
"engine_prepare_audio_batch_size": float(len(tasks)),
"engine_prepare_text_batch_size": float(len(tasks)),
"engine_prepare_audio_phase_mode": 2.0,
"engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)),
"engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
"engine_prepare_text_phase_batch_size": float(len(tasks)),
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
"engine_prepare_audio_start_ts": float(task.audio_start_time),
"engine_prepare_audio_end_ts": float(task.audio_end_time),
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
"engine_prepare_text_start_ts": float(task.text_start_time),
"engine_prepare_text_end_ts": float(task.text_end_time),
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
},
)
task.state_result = state
self._maybe_apply_ref_spec_to_state(task)
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
if task.engine_request_id not in [None, ""]: if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile( self.merge_request_state_profile(
str(task.engine_request_id), str(task.engine_request_id),
{ {
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms), "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
"engine_gpu_prepare_batch_size": float(len(tasks)), "engine_gpu_prepare_batch_size": float(len(tasks)),
}, },
) )
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
completed_count += 1 completed_count += 1
self.prepare_queue_owner.mark_completed(completed_count) self.prepare_text_queue_owner.mark_completed(completed_count)
return True return True
def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
runnable_tasks: list[EngineGpuPrepareTask] = []
queue_wait_ms_list: list[float] = []
completed_count = 0
for task in tasks:
if task.cancelled or task.phase_one is None:
completed_count += 1
continue
task.ref_spec_start_time = float(now)
runnable_tasks.append(task)
queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0))
if not runnable_tasks:
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
return True
batch_results = asyncio.run(
self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks])
)
for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results):
task.ref_spec_end_time = time.perf_counter()
task.ref_spec_queue_wait_ms = float(queue_wait_ms)
if isinstance(result, Exception):
self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms))
completed_count += 1
continue
task.ref_spec_result = result
self._maybe_apply_ref_spec_to_state(task)
if task.state_result is not None:
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time)
task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time)
task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time)
completed_count += 1
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
return True
def run_engine_prepare_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1))
audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
if self.prepare_text_queue_owner.has_items() and (
not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms
):
return self._run_engine_prepare_text_once(batch_max_items)
if self.prepare_queue_owner.has_items():
return self._run_engine_prepare_audio_once(batch_max_items)
if self.prepare_ref_spec_queue_owner.has_items():
return self._run_engine_prepare_ref_spec_once(batch_max_items)
if self.prepare_text_queue_owner.has_items():
return self._run_engine_prepare_text_once(batch_max_items)
if self.prepare_ref_spec_queue_owner.has_items():
return self._run_engine_prepare_ref_spec_once(batch_max_items)
return False
def run_engine_prepare_audio_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def run_engine_prepare_text_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def run_engine_prepare_ref_spec_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))

View File

@ -151,7 +151,9 @@ class WorkerFinalizeExecutor:
@staticmethod @staticmethod
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]: def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
refer_specs = [job.state.refer_spec] refer_specs = []
if job.state.refer_spec is not None:
refer_specs.append(job.state.refer_spec)
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or [])) refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
return refer_specs return refer_specs

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import os import os
import time import time
from typing import Callable, Dict, List from typing import Any, Callable, Dict, List
from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
@ -81,11 +81,60 @@ class WorkerPrepareExecutor:
cpu_stages: List[PreparedCpuStage], cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]: ) -> List[tuple[T2SRequestState, float, float] | Exception]:
try: try:
return list( return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages)
await asyncio.gather( finally:
*[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages], self._notify_state_change()
return_exceptions=True,
) async def prepare_gpu_audio_phases_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[Dict[str, Any] | Exception]:
try:
return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages)
finally:
self._notify_state_change()
async def prepare_gpu_text_phases_async(
self,
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> List[Dict[str, Any] | Exception]:
try:
return await self.coordinator.prepare_gpu_text_phases_async(items)
finally:
self._notify_state_change()
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
try:
return self.coordinator.build_gpu_prepare_result_from_phases(
cpu_stage,
phase_one,
phase_two,
extra_profile=extra_profile,
) )
finally: finally:
self._notify_state_change() self._notify_state_change()
async def prepare_ref_spec_stages_async(
self,
phase_ones: List[Dict[str, Any]],
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
try:
return await self.coordinator.prepare_ref_spec_stages_async(phase_ones)
finally:
self._notify_state_change()
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
try:
self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result)
finally:
self._notify_state_change()

View File

@ -267,3 +267,42 @@ class WorkerSubmitLifecycleMixin:
cpu_stages: List[PreparedCpuStage], cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]: ) -> List[tuple[T2SRequestState, float, float] | Exception]:
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[Dict[str, Any] | Exception]:
return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages)
async def prepare_gpu_text_phases_async(
self,
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> List[Dict[str, Any] | Exception]:
return await self.prepare_executor.prepare_gpu_text_phases_async(items)
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
return self.prepare_executor.build_gpu_prepare_result_from_phases(
cpu_stage,
phase_one,
phase_two,
extra_profile=extra_profile,
)
async def prepare_ref_spec_stages_async(
self,
phase_ones: List[Dict[str, Any]],
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones)
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result)

View File

@ -244,6 +244,16 @@ class G2PWRuntimeWrapper:
) )
self.batch_worker.start() self.batch_worker.start()
def _sync_runtime_env_overrides(self) -> None:
os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0"
os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0"
os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0"
os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit))
os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit))
os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0"
os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0"
os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32")
def _destroy_handle(self) -> None: def _destroy_handle(self) -> None:
if self.handle: if self.handle:
self.lib.g2pw_runtime_destroy(self.handle) self.lib.g2pw_runtime_destroy(self.handle)
@ -268,6 +278,7 @@ class G2PWRuntimeWrapper:
return "" if not message else message.decode("utf-8", errors="replace") return "" if not message else message.decode("utf-8", errors="replace")
def _create_handle(self, batch_size: int, seq_len: int) -> None: def _create_handle(self, batch_size: int, seq_len: int) -> None:
self._sync_runtime_env_overrides()
new_handle = self.lib.g2pw_runtime_create( new_handle = self.lib.g2pw_runtime_create(
str(self.manifest_path).encode("utf-8"), str(self.manifest_path).encode("utf-8"),
str(self.weights_path).encode("utf-8"), str(self.weights_path).encode("utf-8"),
@ -518,6 +529,10 @@ class G2PWRuntimeWrapper:
return { return {
"shard_index": int(self.shard_index), "shard_index": int(self.shard_index),
"enabled": bool(self.batch_enabled), "enabled": bool(self.batch_enabled),
"enable_cuda_graph": bool(self.enable_cuda_graph),
"enable_profiling": bool(self.enable_profiling),
"full_graph_cache_limit": int(self.full_graph_cache_limit),
"tail_graph_cache_limit": int(self.tail_graph_cache_limit),
"window_ms": float(self.batch_window_s * 1000.0), "window_ms": float(self.batch_window_s * 1000.0),
"max_requests": int(self.batch_max_requests), "max_requests": int(self.batch_max_requests),
"max_rows": int(self.batch_max_rows), "max_rows": int(self.batch_max_rows),