From 8a444c10b72aadeb1f8be9210b45f2a519d383e4 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Fri, 13 Mar 2026 16:45:38 +0800 Subject: [PATCH] Enhance TTS processing with new reference specification handling and profiling metrics Refactor the PrepareCoordinator and related components to improve the handling of reference specifications in the TTS system. Introduce new methods for building and extracting reference prompts and specifications, along with detailed profiling metrics for performance monitoring. Update the PrepareRefSemanticBatchWorker to include additional timing metrics and caching mechanisms for resampling. These changes enhance the efficiency and maintainability of the TTS framework, particularly in managing audio processing and reference data. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 23 +- .../TTS_infer_pack/prepare_coordinator.py | 537 +++++++++++++----- .../prepare_ref_semantic_batch_worker.py | 61 +- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 20 +- .../unified_engine_bridge_registry.py | 40 +- .../TTS_infer_pack/unified_engine_builder.py | 4 + .../unified_engine_component_policy.py | 38 +- .../unified_engine_component_runtime.py | 18 +- .../unified_engine_orchestration.py | 24 + .../TTS_infer_pack/unified_engine_stage.py | 15 + .../unified_engine_stage_dispatch.py | 4 + .../unified_engine_stage_executor.py | 4 + .../unified_engine_stage_finalize.py | 39 +- .../unified_engine_stage_prepare.py | 234 +++++++- .../unified_engine_worker_finalize.py | 4 +- .../unified_engine_worker_prepare.py | 61 +- .../unified_engine_worker_submit.py | 39 ++ GPT_SoVITS/text/g2pw/cuda_api.py | 15 + 18 files changed, 1006 insertions(+), 174 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 78bf7178..16bc8db8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -996,21 +996,39 @@ class TTS: return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return spec, audio, raw_audio, raw_sr + + def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + profile = { + "ref_spec_to_device_ms": 0.0, + "ref_spec_main_resample_ms": 0.0, + "ref_spec_norm_ms": 0.0, + "ref_spec_spectrogram_ms": 0.0, + "ref_spec_post_resample_ms": 0.0, + } + to_device_start = time.perf_counter() raw_audio_device = raw_audio.to(self.configs.device).float() + profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0 if raw_sr != self.configs.sampling_rate: + resample_start = time.perf_counter() audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) + profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0 else: audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) + norm_start = time.perf_counter() maxx = audio.abs().max() if maxx > 1: audio /= min(2, maxx) + profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0 + spec_start = time.perf_counter() spec = spectrogram_torch( audio, self.configs.filter_length, @@ -1019,15 +1037,18 @@ class TTS: self.configs.win_length, center=False, ) + profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0 if self.configs.is_half: spec = spec.half() if self.is_v2pro == True: + post_resample_start = time.perf_counter() audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) + profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0 if self.configs.is_half: audio = audio.half() else: audio = None - return spec, audio, raw_audio, raw_sr + return spec, audio, raw_audio, raw_sr, profile def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 65bcbb51..e74e0de4 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -228,8 +228,74 @@ class PrepareCoordinator: def _load_ref_audio_raw(self, ref_audio_path: str): return self.tts._load_ref_audio_raw(ref_audio_path) + def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int): + load_profile = {"audio_load_ms": 0.0} + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + **load_profile, + "audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)), + "audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_ms": float( + worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + + worker_profile.get("prompt_semantic_forward_ms", 0.0) + + worker_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + **{key: float(value) for key, value in worker_profile.items()}, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)), + }, + } + wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) + with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats: + prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": 0.0, + "audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "audio_stage_slots": float(stage_stats.get("slots", 0.0)), + "audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": 0.0, + "prompt_semantic_h2d_ms": 0.0, + "prompt_semantic_ssl_forward_ms": 0.0, + "prompt_semantic_hidden_length_ms": 0.0, + "prompt_semantic_extract_latent_ms": 0.0, + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)), + }, + } + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): - return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return (spec, audio), profile @staticmethod def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: @@ -523,7 +589,7 @@ class PrepareCoordinator: finally: self.text_feature_gate.release() - async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: + async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult: if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: submit_at = time.perf_counter() started_at = float(submit_at) @@ -538,19 +604,7 @@ class PrepareCoordinator: prompt_semantic_task = asyncio.create_task( self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) ) - await self.ref_spec_gate.acquire() - try: - ref_spec_task = asyncio.create_task( - self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) - ) - (prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather( - prompt_semantic_task, - ref_spec_task, - ) - finally: - self.ref_spec_gate.release() - - refer_spec = ref_spec_profiled.result + prompt_semantic, prompt_semantic_profile = await prompt_semantic_task limiter_snapshot = ( self.tts.prepare_ref_audio_stage_limiter.snapshot() if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None @@ -561,21 +615,15 @@ class PrepareCoordinator: + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) ) - audio_stage_wait_ms = ( - float(load_profiled.queue_ms) - + float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) - + float(ref_spec_profiled.queue_ms) - ) finished_at = time.perf_counter() result = { "prompt_semantic": prompt_semantic, - "refer_spec": refer_spec, "raw_audio": raw_audio, "raw_sr": raw_sr, "profile": { "audio_load_queue_ms": float(load_profiled.queue_ms), "audio_load_ms": float(load_profiled.run_ms), - "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "audio_stage_slots": float( max( float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -593,6 +641,17 @@ class PrepareCoordinator: "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), + "prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0) + ), + "prompt_semantic_hidden_length_ms": float( + prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0) + ), + "prompt_semantic_extract_latent_ms": float( + prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0) + ), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -603,14 +662,10 @@ class PrepareCoordinator: "prompt_semantic_batch_samples": float( prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) ), - "ref_spec_wait_ms": float(ref_spec_profiled.queue_ms), - "ref_spec_ms": float(ref_spec_profiled.run_ms), "bundle_total_ms": float( load_profiled.queue_ms + load_profiled.run_ms + prompt_semantic_ms - + ref_spec_profiled.queue_ms - + ref_spec_profiled.run_ms ), }, } @@ -623,21 +678,26 @@ class PrepareCoordinator: await self.ref_audio_gate.acquire() try: - if hasattr(self.tts, "extract_ref_audio_bundle_async"): - submit_at = time.perf_counter() - started_at = time.perf_counter() - result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) - finished_at = time.perf_counter() - return ProfiledResult( - result=result, - submit_at=float(submit_at), - started_at=float(started_at), - finished_at=float(finished_at), - ) - return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + raw_audio, raw_sr = load_profiled.result + submit_at = time.perf_counter() + started_at = time.perf_counter() + result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr) + result.setdefault("profile", {}) + result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms) + result["profile"]["audio_load_ms"] = float(load_profiled.run_ms) + finished_at = time.perf_counter() + return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at)) finally: self.ref_audio_gate.release() + async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult: + await self.ref_spec_gate.acquire() + try: + return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + finally: + self.ref_spec_gate.release() + def _release_split_stage_slot(self) -> None: self._mark_leave() self._inflight_gate.release() @@ -682,101 +742,318 @@ class PrepareCoordinator: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: try: - g2pw_pair_start = time.perf_counter() - g2pw_pair_task = asyncio.create_task( - self._run_g2pw_pair_stage( - cpu_stage.prompt_cpu_profiled.result, - cpu_stage.target_cpu_profiled.result, - ) - ) - ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) - (prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather( - g2pw_pair_task, - ref_audio_task, - ) - g2pw_pair_end = time.perf_counter() - text_pair_start = time.perf_counter() - text_feature_pair_task = asyncio.create_task( - self._run_text_feature_pair_stage( - prompt_g2pw_profiled.result, - target_g2pw_profiled.result, - cpu_stage.prompt_cpu_profiled.run_ms, - cpu_stage.target_cpu_profiled.run_ms, - prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), - target_base_profile=dict(target_g2pw_profiled.profile or {}), - ) - ) - prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task - text_pair_end = time.perf_counter() - state = build_request_state_from_parts( - tts=self.tts, - spec=cpu_stage.spec, - prompt_text=cpu_stage.prompt_text, - text=cpu_stage.text, - prompt_result=prompt_feature_profiled.result, - target_result=target_feature_profiled.result, - ref_audio_bundle=ref_audio_profiled.result, - prepare_start=cpu_stage.prepare_start, - prepare_sync_start=cpu_stage.prepare_start, - profile_overrides={ - "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), - "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, - "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), - "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), - "g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0), - "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, - "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, - "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), - "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), - "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), - "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, - "text_g2pw_run_ms": target_g2pw_profiled.run_ms, - "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), - "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), - "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), - "prompt_text_parallel_future_wait_ms": 0.0, - "prompt_text_parallel_future_executor_queue_ms": 0.0, - "prompt_text_parallel_future_run_ms": 0.0, - "prompt_text_parallel_future_finish_after_submit_ms": 0.0, - "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, - "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, - "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, - "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, - "prompt_text_cpu_admission_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) - ), - "prompt_text_cpu_backpressure_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) - ), - "prompt_text_cpu_capacity_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) - ), - "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, - "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, - "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, - "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, - "text_cpu_admission_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) - ), - "text_cpu_backpressure_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) - ), - "text_cpu_capacity_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) - ), - "text_feature_queue_ms": target_feature_profiled.queue_ms, - "text_feature_run_ms": target_feature_profiled.run_ms, - "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, - "ref_audio_task_run_ms": ref_audio_profiled.run_ms, - "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), - "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + phase_one = await self._prepare_gpu_phase_one(cpu_stage) + phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one) + return self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 0.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]), + "engine_prepare_audio_phase_batch_size": 1.0, + "engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]), + "engine_prepare_text_phase_batch_size": 1.0, }, ) - prepare_exec_finished_at = time.perf_counter() - state.prepare_profile["executor_run_wall_ms"] = max( - 0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0 + finally: + self._release_split_stage_slot() + + async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]: + phase_start = time.perf_counter() + g2pw_pair_task = asyncio.create_task( + self._run_g2pw_pair_stage( + cpu_stage.prompt_cpu_profiled.result, + cpu_stage.target_cpu_profiled.result, ) - return state, cpu_stage.prepare_start, prepare_exec_finished_at + ) + ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path))) + prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task + g2pw_pair_end = time.perf_counter() + ref_audio_profiled = await ref_audio_task + phase_end = time.perf_counter() + return { + "prompt_g2pw_profiled": prompt_g2pw_profiled, + "target_g2pw_profiled": target_g2pw_profiled, + "ref_audio_profiled": ref_audio_profiled, + "ref_spec_result": None, + "g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0), + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + async def _prepare_gpu_phase_two( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + ) -> Dict[str, Any]: + phase_start = time.perf_counter() + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage( + prompt_g2pw_profiled.result, + target_g2pw_profiled.result, + cpu_stage.prompt_cpu_profiled.run_ms, + cpu_stage.target_cpu_profiled.run_ms, + prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), + target_base_profile=dict(target_g2pw_profiled.profile or {}), + ) + phase_end = time.perf_counter() + return { + "prompt_feature_profiled": prompt_feature_profiled, + "target_feature_profiled": target_feature_profiled, + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + def _build_gpu_prepare_result( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + ref_audio_profiled = phase_one["ref_audio_profiled"] + ref_spec_result = phase_one.get("ref_spec_result") + prompt_feature_profiled = phase_two["prompt_feature_profiled"] + target_feature_profiled = phase_two["target_feature_profiled"] + profile_overrides = { + "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, + "prepare_submit_ts": float(cpu_stage.prepare_submit_at), + "prepare_cpu_start_ts": float(cpu_stage.prepare_start), + "prepare_cpu_done_ts": float( + max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at) + ), + "prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at), + "prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at), + "text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at), + "text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at), + "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), + "text_feature_pair_ms": float(phase_two["phase_wall_ms"]), + "g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]), + "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, + "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, + "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, + "text_g2pw_run_ms": target_g2pw_profiled.run_ms, + "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + "prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, + "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), + "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + } + if extra_profile: + profile_overrides.update({key: float(value) for key, value in extra_profile.items()}) + ref_audio_bundle = dict(ref_audio_profiled.result) + ref_audio_profile = dict(ref_audio_bundle.get("profile", {})) + if ref_spec_result is not None: + refer_spec, ref_spec_profiled = ref_spec_result + ref_audio_bundle["refer_spec"] = refer_spec + ref_audio_profile.update( + { + "ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)), + } + ) + else: + ref_audio_bundle["refer_spec"] = None + ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0) + ref_audio_bundle["profile"] = ref_audio_profile + state = build_request_state_from_parts( + tts=self.tts, + spec=cpu_stage.spec, + prompt_text=cpu_stage.prompt_text, + text=cpu_stage.text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=cpu_stage.prepare_start, + prepare_sync_start=cpu_stage.prepare_start, + profile_overrides=profile_overrides, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0) + return state, cpu_stage.prepare_start, prepare_exec_finished_at + + async def prepare_ref_spec_stages_async( + self, + phase_ones: list[Dict[str, Any]], + ) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + async def _one(phase_one: Dict[str, Any]): + ref_audio_profiled = phase_one["ref_audio_profiled"] + raw_audio = ref_audio_profiled.result["raw_audio"] + raw_sr = int(ref_audio_profiled.result["raw_sr"]) + profiled = await self._run_ref_spec_stage(raw_audio, raw_sr) + refer_spec, profile = profiled.result + merged_profile = dict(profile) + merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms) + merged_profile["ref_spec_ms"] = float(profiled.run_ms) + return refer_spec, merged_profile + + if not phase_ones: + return [] + return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True)) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + refer_spec, profile = ref_spec_result + state.refer_spec = refer_spec + state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0)) + state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0)) + state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0)) + state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0)) + state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0)) + state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0)) + state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0)) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[tuple[T2SRequestState, float, float] | Exception]: + if not cpu_stages: + return [] + if len(cpu_stages) == 1: + single_stage = cpu_stages[0] + try: + return [await self.prepare_gpu_stage_profiled_async(single_stage)] + except Exception as exc: # noqa: PERF203 + return [exc] + + phase_one_started_at = time.perf_counter() + phase_one_results = await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + phase_one_finished_at = time.perf_counter() + phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0) + + outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages) + pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = [] + for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)): + if isinstance(phase_one, Exception): + outputs[index] = phase_one + self._release_split_stage_slot() + continue + pending_phase_two.append((index, cpu_stage, phase_one)) + + phase_two_started_at = time.perf_counter() + phase_two_results = await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two], + return_exceptions=True, + ) + phase_two_finished_at = time.perf_counter() + phase_two_wall_ms = max(0.0, (phase_two_finished_at - phase_two_started_at) * 1000.0) + + for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results): + try: + if isinstance(phase_two, Exception): + outputs[index] = phase_two + continue + outputs[index] = self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 1.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms), + "engine_prepare_audio_phase_batch_size": float(len(cpu_stages)), + "engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms), + "engine_prepare_text_phase_batch_size": float(len(pending_phase_two)), + }, + ) + except Exception as exc: # noqa: PERF203 + outputs[index] = exc + finally: + self._release_split_stage_slot() + + return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs] + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[Dict[str, Any] | Exception]: + if not cpu_stages: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + + async def prepare_gpu_text_phases_async( + self, + items: list[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> list[Dict[str, Any] | Exception]: + if not items: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items], + return_exceptions=True, + ) + ) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile) finally: self._release_split_stage_slot() diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index ff5591b2..4628a2a2 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -15,6 +15,7 @@ REF_AUDIO_MIN_SAMPLES_16K = 48000 REF_AUDIO_MAX_SAMPLES_16K = 160000 _RESAMPLE_CACHE_LOCK = threading.Lock() _RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {} +_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {} def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample: @@ -28,6 +29,16 @@ def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.tran return transform +def _get_resample_stream(device: str) -> torch.cuda.Stream: + device_key = str(device) + with _RESAMPLE_CACHE_LOCK: + stream = _RESAMPLE_STREAM_CACHE.get(device_key) + if stream is None: + stream = torch.cuda.Stream(device=device_key) + _RESAMPLE_STREAM_CACHE[device_key] = stream + return stream + + def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu" if resample_device not in {"cpu", "cuda"}: @@ -37,10 +48,20 @@ def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wa wav_mono = raw_audio if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: wav_mono = wav_mono.mean(0, keepdim=True) - wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) - if raw_sr != 16000: - wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) - wav16k = wav16k.squeeze(0).contiguous() + if resample_device == "cuda": + stream = _get_resample_stream(resample_device) + with torch.cuda.stream(stream): + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() + stream.synchronize() + wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous() + else: + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: raise OSError("参考音频在3~10秒范围外,请更换!") if zero_wav_samples > 0: @@ -256,37 +277,56 @@ class PrepareRefSemanticBatchWorker: batch_samples = int(wav_lengths.sum().item()) max_wav_len = int(wav_lengths.max().item()) + pack_start = time.perf_counter() input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) for batch_index, wav in enumerate(prepared_wavs): wav_len = int(wav.shape[0]) input_values_cpu[batch_index, :wav_len] = wav attention_mask_cpu[batch_index, :wav_len] = 1 + pack_ms = (time.perf_counter() - pack_start) * 1000.0 limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + h2d_ms = 0.0 + ssl_forward_ms = 0.0 + hidden_length_ms = 0.0 + extract_latent_ms = 0.0 if self.stage_limiter is None: + h2d_start = time.perf_counter() input_values = input_values_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device) if self.is_half: input_values = input_values.half() - forward_start = time.perf_counter() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() codes = self.vits_model.extract_latent(hubert_feature) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 else: with self.stage_limiter.enter() as limiter_stats: + h2d_start = time.perf_counter() input_values = input_values_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device) if self.is_half: input_values = input_values.half() - forward_start = time.perf_counter() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() codes = self.vits_model.extract_latent(hubert_feature) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 + forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms) code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) scatter_start = time.perf_counter() @@ -308,6 +348,11 @@ class PrepareRefSemanticBatchWorker: 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 ), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": float(pack_ms), + "prompt_semantic_h2d_ms": float(h2d_ms), + "prompt_semantic_ssl_forward_ms": float(ssl_forward_ms), + "prompt_semantic_hidden_length_ms": float(hidden_length_ms), + "prompt_semantic_extract_latent_ms": float(extract_latent_ms), "prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_scatter_ms": 0.0, "prompt_semantic_calls": 1.0, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 73e2a2c7..a4d462d0 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -55,7 +55,7 @@ class T2SRequestState: all_phones: torch.LongTensor all_bert_features: torch.Tensor prompt_semantic: torch.LongTensor - refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] raw_audio: torch.Tensor raw_sr: int @@ -188,7 +188,11 @@ def build_request_state_from_parts( ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() - spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + refer_spec_value = ref_audio_bundle.get("refer_spec") + if refer_spec_value in [None, ()]: + spec_audio, audio_16k = None, None + else: + spec_audio, audio_16k = refer_spec_value aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [] for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []): if aux_ref_audio_path in [None, ""]: @@ -323,6 +327,11 @@ def build_request_state_from_parts( bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) ), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)), + "prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)), + "prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -331,6 +340,11 @@ def build_request_state_from_parts( "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_ms": ref_spec_ms, + "ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)), "ref_audio_bundle_ms": ref_audio_bundle_ms, "tensorize_ms": tensorize_ms, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, @@ -352,7 +366,7 @@ def build_request_state_from_parts( all_phones=all_phones, all_bert_features=all_bert_features, prompt_semantic=prompt_semantic, - refer_spec=(spec_audio, audio_16k), + refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)), aux_refer_specs=aux_refer_specs, raw_audio=raw_audio, raw_sr=raw_sr, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py index 88b8cc5d..f07250e1 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py @@ -21,6 +21,14 @@ class EngineRegistryBridgeFacade: def engine_prepare_queue_owner(self): return self.owner.engine_prepare_queue_owner + @property + def engine_prepare_text_queue_owner(self): + return self.owner.engine_prepare_text_queue_owner + + @property + def engine_prepare_ref_spec_queue_owner(self): + return self.owner.engine_prepare_ref_spec_queue_owner + @property def engine_finalize_queue_owner(self): return self.owner.engine_finalize_queue_owner @@ -82,7 +90,33 @@ class EngineRegistryBridgeFacade: return self.request_registry.snapshot() def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16) + ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16) + return { + "waiting_count": int(audio_snapshot.get("waiting_count", 0)) + + int(text_snapshot.get("waiting_count", 0)) + + int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)), + "text_waiting_count": int(text_snapshot.get("waiting_count", 0)), + "ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])), + "text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])), + "ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])), + "peak_waiting": int( + max( + int(audio_snapshot.get("peak_waiting", 0)), + int(text_snapshot.get("peak_waiting", 0)), + int(ref_spec_snapshot.get("peak_waiting", 0)), + ) + ), + "total_submitted": int(audio_snapshot.get("total_submitted", 0)), + "total_completed": int(audio_snapshot.get("total_completed", 0)), + "text_total_submitted": int(text_snapshot.get("total_submitted", 0)), + "text_total_completed": int(text_snapshot.get("total_completed", 0)), + "ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)), + "ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)), + } def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) @@ -107,6 +141,8 @@ class EngineRegistryBridgeFacade: def _is_engine_drained(self) -> bool: prepare_empty = self.engine_prepare_queue_owner.is_drained() + prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained() + prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained() dispatch_empty = self.engine_dispatch_queue_owner.is_drained() finalize_empty = self.engine_finalize_queue_owner.is_drained() decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() @@ -114,6 +150,8 @@ class EngineRegistryBridgeFacade: worker_state = self.scheduler_worker.snapshot() return bool( prepare_empty + and prepare_text_empty + and prepare_ref_spec_empty and dispatch_empty and finalize_empty and decode_pending_empty diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py index 0e93c442..2cb1e175 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -110,6 +110,8 @@ class EngineCompositionBuilder: get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, ) owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") @@ -119,6 +121,8 @@ class EngineCompositionBuilder: tts=owner.tts, scheduler_worker=owner.scheduler_worker, prepare_queue_owner=owner.engine_prepare_queue_owner, + prepare_text_queue_owner=owner.engine_prepare_text_queue_owner, + prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner, finalize_queue_owner=owner.engine_finalize_queue_owner, dispatch_queue_owner=owner.engine_dispatch_queue_owner, decode_runtime_owner=owner.engine_decode_runtime_owner, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py index b6c5ca4d..65953dd9 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -129,7 +129,7 @@ class EnginePolicyArbiterController: self.state.total_ticks += 1 if stage == "idle": self.state.total_idle_ticks += 1 - elif stage == "prepare": + elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}: self.state.total_prepare_dispatches += 1 self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) elif stage == "finalize": @@ -282,7 +282,11 @@ class EnginePolicyArbiterController: request_registry = self.snapshot_request_registry() worker_state = self.get_worker_state() policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + prepare_state = self.snapshot_prepare_state() + prepare_waiting = int(prepare_state.get("waiting_count", 0)) + prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0)) + prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0)) + prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0)) finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) decode_runtime_state = self.snapshot_decode_runtime_state() @@ -291,6 +295,9 @@ class EnginePolicyArbiterController: worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio")) + prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text")) + prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec")) finalize_age_ms = float(self.peek_queue_age_ms("finalize")) decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) @@ -316,14 +323,31 @@ class EnginePolicyArbiterController: and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) ): return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if ( + finalize_waiting > 0 + and prepare_ref_spec_waiting > 0 + and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0) + ): + return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): return "finalize", "finalize_aging", policy_snapshot, worker_state if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0: + return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state + if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6): + return "prepare_text", "prepare_aging", policy_snapshot, worker_state + if ( + prepare_ref_spec_waiting > 0 + and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6) + ): + return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state + return "prepare_audio", "prepare_aging", policy_snapshot, worker_state if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state if decode_waiting > 0 and policy_allowed: @@ -331,5 +355,9 @@ class EnginePolicyArbiterController: if finalize_waiting > 0: return "finalize", "finalize_fallback", policy_snapshot, worker_state if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "prepare_fallback", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0: + return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state + return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index 15eedeca..600e1e83 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -324,8 +324,24 @@ class EngineGpuPrepareTask: done_future: asyncio.Future | None engine_request_id: str | None enqueue_time: float - queue_wait_ms: float = 0.0 + phase: str = "audio" + audio_enqueue_time: float = 0.0 + audio_start_time: float = 0.0 + audio_end_time: float = 0.0 + text_enqueue_time: float = 0.0 + text_start_time: float = 0.0 + text_end_time: float = 0.0 + ref_spec_enqueue_time: float = 0.0 + ref_spec_start_time: float = 0.0 + ref_spec_end_time: float = 0.0 + audio_queue_wait_ms: float = 0.0 + text_queue_wait_ms: float = 0.0 + ref_spec_queue_wait_ms: float = 0.0 admission_wait_ms: float = 0.0 + phase_one: Dict[str, Any] | None = None + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None + state_result: T2SRequestState | None = None + cancelled: bool = False error: str | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py index a71f7e4e..0c73616f 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -14,6 +14,8 @@ class EngineStageOrchestrator: executor: EngineStageExecutor, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -22,6 +24,8 @@ class EngineStageOrchestrator: self.executor = executor self.scheduler_worker = scheduler_worker self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner self.finalize_queue_owner = finalize_queue_owner self.dispatch_queue_owner = dispatch_queue_owner self.decode_runtime_owner = decode_runtime_owner @@ -45,7 +49,17 @@ class EngineStageOrchestrator: def peek_queue_age_ms(self, queue_name: str) -> float: if queue_name == "prepare": + return max( + self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"), + ) + if queue_name == "prepare_audio": return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_text": + return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_ref_spec": + return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time") if queue_name == "finalize": return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") if queue_name == "decode_runtime_pending": @@ -62,6 +76,10 @@ class EngineStageOrchestrator: return True if self.prepare_queue_owner.has_items(): return True + if self.prepare_text_queue_owner.has_items(): + return True + if self.prepare_ref_spec_queue_owner.has_items(): + return True if self.finalize_queue_owner.has_items(): return True return self.dispatch_queue_owner.has_items() @@ -79,6 +97,12 @@ class EngineStageOrchestrator: executed = False if stage == "prepare": executed = self.executor.run_engine_prepare_once() + elif stage == "prepare_audio": + executed = self.executor.run_engine_prepare_audio_once() + elif stage == "prepare_text": + executed = self.executor.run_engine_prepare_text_once() + elif stage == "prepare_ref_spec": + executed = self.executor.run_engine_prepare_ref_spec_once() elif stage == "finalize": executed = self.executor.run_engine_finalize_once() elif stage == "decode_dispatch": diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 1b872dfa..27ed3bf5 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -24,6 +24,8 @@ class EngineStageCoordinator: tts: TTS, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -45,6 +47,8 @@ class EngineStageCoordinator: tts=tts, scheduler_worker=scheduler_worker, prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, finalize_queue_owner=finalize_queue_owner, dispatch_queue_owner=dispatch_queue_owner, decode_runtime_owner=decode_runtime_owner, @@ -66,6 +70,8 @@ class EngineStageCoordinator: executor=self.executor, scheduler_worker=scheduler_worker, prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, finalize_queue_owner=finalize_queue_owner, dispatch_queue_owner=dispatch_queue_owner, decode_runtime_owner=decode_runtime_owner, @@ -144,6 +150,15 @@ class EngineStageCoordinator: def run_engine_prepare_once(self) -> bool: return self.executor.run_engine_prepare_once() + def run_engine_prepare_audio_once(self) -> bool: + return self.executor.run_engine_prepare_audio_once() + + def run_engine_prepare_text_once(self) -> bool: + return self.executor.run_engine_prepare_text_once() + + def run_engine_prepare_ref_spec_once(self) -> bool: + return self.executor.run_engine_prepare_ref_spec_once() + def run_engine_finalize_once(self) -> bool: return self.executor.run_engine_finalize_once() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py index 644c35f6..f6a249fa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py @@ -24,6 +24,10 @@ class EngineDispatchStageMixin: engine_request_id: str | None, timeout_sec: float | None, ) -> EngineDispatchTask: + if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + error = RuntimeError("ref_spec async stage failed before dispatch") + self.fail_request_state(engine_request_id or state.request_id, str(error)) + raise error task = EngineDispatchTask( request_id=state.request_id, state=state, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py index 01921d51..6d06f0c6 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -31,6 +31,8 @@ class EngineStageExecutor( tts: TTS, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -51,6 +53,8 @@ class EngineStageExecutor( self.tts = tts self.scheduler_worker = scheduler_worker self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner self.finalize_queue_owner = finalize_queue_owner self.dispatch_queue_owner = dispatch_queue_owner self.decode_runtime_owner = decode_runtime_owner diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py index 8e66f76e..4b61993e 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py @@ -39,10 +39,37 @@ class EngineFinalizeStageMixin: tasks = self.take_engine_finalize_batch_nonblocking() if not tasks: return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) + ready_tasks: List[SchedulerFinalizeTask] = [] + failed_tasks: List[SchedulerFinalizeTask] = [] + deferred_tasks: List[SchedulerFinalizeTask] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + failed_tasks.append(task) + continue + if job.state.refer_spec is None: + deferred_tasks.append(task) + self.merge_request_state_profile( + job.engine_request_id or job.request_id, + { + "engine_finalize_ref_spec_blocked": 1.0, + }, + ) + continue + ready_tasks.append(task) + if deferred_tasks: + self.finalize_queue_owner.enqueue_many(deferred_tasks) + if failed_tasks: + self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed") + if not ready_tasks: + self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True) + return False + self.scheduler_worker.begin_finalize_execution(len(ready_tasks)) try: jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: + for task in ready_tasks: job = self.get_engine_job(task.request_id) if job is None: continue @@ -50,7 +77,7 @@ class EngineFinalizeStageMixin: if not jobs_and_items: return False now = time.perf_counter() - for task in tasks: + for task in ready_tasks: job = self.get_engine_job(task.request_id) if job is not None: job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) @@ -69,8 +96,8 @@ class EngineFinalizeStageMixin: for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc)) finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + self.scheduler_worker.end_finalize_execution(len(ready_tasks)) + self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True) return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py index b9095d2c..1ea7c45b 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -10,6 +10,13 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare class EnginePrepareStageMixin: + def _prepare_waiting_total(self) -> int: + return ( + int(self.prepare_queue_owner.waiting_count()) + + int(self.prepare_text_queue_owner.waiting_count()) + + int(self.prepare_ref_spec_queue_owner.waiting_count()) + ) + async def _wait_prepare_queue_admission(self) -> float: soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) if soft_max <= 0: @@ -19,7 +26,7 @@ class EnginePrepareStageMixin: float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, ) wait_start = time.perf_counter() - while self.prepare_queue_owner.waiting_count() >= soft_max: + while self._prepare_waiting_total() >= soft_max: await asyncio.sleep(poll_s) return max(0.0, (time.perf_counter() - wait_start) * 1000.0) @@ -53,44 +60,247 @@ class EnginePrepareStageMixin: done_future=done_future, engine_request_id=engine_request_id or spec.request_id, enqueue_time=time.perf_counter(), + phase="audio", + audio_enqueue_time=time.perf_counter(), admission_wait_ms=float(prepare_queue_admission_wait_ms), ) self.prepare_queue_owner.enqueue(task) self.notify_arbiter() return await done_future - def run_engine_prepare_once(self) -> bool: - prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() - tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + def _should_chain_prepare_text_after_audio(self) -> bool: + if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}: + return False + if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items(): + return False + decode_runtime_state = self.snapshot_engine_decode_runtime_state() + if bool(decode_runtime_state.get("has_work", False)): + return False + return True + + def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None: + if task.state_result is None or task.ref_spec_result is None: + return + self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms), + "ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)), + }, + ) + + def _mark_ref_spec_async_failed( + self, + task: EngineGpuPrepareTask, + error: Exception, + *, + queue_wait_ms: float, + ) -> None: + task.error = str(error) + task.cancelled = True + if task.state_result is not None: + task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0 + task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "ref_spec_async_failed": 1.0, + "engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms), + }, + ) + self.fail_request_state(task.engine_request_id or task.request_id, str(error)) + self.fail_engine_jobs([task.request_id], str(error)) + self.notify_arbiter() + + def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_queue_owner.pop_left_many(batch_max_items) if not tasks: return False now = time.perf_counter() queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] - batch_results = asyncio.run( - self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks]) - ) + for task in tasks: + task.audio_start_time = float(now) + batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks])) completed_count = 0 for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + task.audio_end_time = time.perf_counter() if isinstance(result, Exception): task.error = str(result) self.fail_request_state(task.engine_request_id or task.request_id, str(result)) self._notify_prepare_error(task, result) completed_count += 1 continue - state, prepare_exec_started_at, prepare_exec_finished_at = result - state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + task.audio_queue_wait_ms = float(queue_wait_ms) + task.phase_one = result + task.phase = "text" + task.enqueue_time = time.perf_counter() + task.text_enqueue_time = float(task.enqueue_time) + task.ref_spec_enqueue_time = float(task.enqueue_time) + self.prepare_text_queue_owner.enqueue(task) + self.prepare_ref_spec_queue_owner.enqueue(task) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(queue_wait_ms), + "engine_prepare_audio_batch_size": float(len(tasks)), + "engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)), + "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time), + "engine_prepare_audio_start_ts": float(task.audio_start_time), + "engine_prepare_audio_end_ts": float(task.audio_end_time), + "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time), + "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time), + }, + ) + completed_count += 1 + self.prepare_queue_owner.mark_completed(completed_count) + if completed_count > 0 and self._should_chain_prepare_text_after_audio(): + self._run_engine_prepare_text_once(min(batch_max_items, completed_count)) + return True + if completed_count > 0: + self.notify_arbiter() + return True + + def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + for task in tasks: + task.text_start_time = float(now) + items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None] + batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items)) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + task.text_end_time = time.perf_counter() + if isinstance(result, Exception): + task.error = str(result) + task.cancelled = True + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + task.text_queue_wait_ms = float(queue_wait_ms) + state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases( + task.cpu_stage, + task.phase_one or {}, + result, + extra_profile={ + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms), + "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms), + "engine_prepare_audio_batch_size": float(len(tasks)), + "engine_prepare_text_batch_size": float(len(tasks)), + "engine_prepare_audio_phase_mode": 2.0, + "engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)), + "engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)), + "engine_prepare_text_phase_batch_size": float(len(tasks)), + "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time), + "engine_prepare_audio_start_ts": float(task.audio_start_time), + "engine_prepare_audio_end_ts": float(task.audio_end_time), + "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time), + "engine_prepare_text_start_ts": float(task.text_start_time), + "engine_prepare_text_end_ts": float(task.text_end_time), + "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time), + }, + ) + task.state_result = state + self._maybe_apply_ref_spec_to_state(task) state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) if task.engine_request_id not in [None, ""]: self.merge_request_state_profile( str(task.engine_request_id), { "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), - "engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms), + "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms), "engine_gpu_prepare_batch_size": float(len(tasks)), }, ) self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) completed_count += 1 - self.prepare_queue_owner.mark_completed(completed_count) + self.prepare_text_queue_owner.mark_completed(completed_count) return True + + def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + runnable_tasks: list[EngineGpuPrepareTask] = [] + queue_wait_ms_list: list[float] = [] + completed_count = 0 + for task in tasks: + if task.cancelled or task.phase_one is None: + completed_count += 1 + continue + task.ref_spec_start_time = float(now) + runnable_tasks.append(task) + queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0)) + if not runnable_tasks: + self.prepare_ref_spec_queue_owner.mark_completed(completed_count) + return True + batch_results = asyncio.run( + self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks]) + ) + for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results): + task.ref_spec_end_time = time.perf_counter() + task.ref_spec_queue_wait_ms = float(queue_wait_ms) + if isinstance(result, Exception): + self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms)) + completed_count += 1 + continue + task.ref_spec_result = result + self._maybe_apply_ref_spec_to_state(task) + if task.state_result is not None: + task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms) + task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time) + task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time) + task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time) + completed_count += 1 + self.prepare_ref_spec_queue_owner.mark_completed(completed_count) + return True + + def run_engine_prepare_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1)) + audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if self.prepare_text_queue_owner.has_items() and ( + not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms + ): + return self._run_engine_prepare_text_once(batch_max_items) + if self.prepare_queue_owner.has_items(): + return self._run_engine_prepare_audio_once(batch_max_items) + if self.prepare_ref_spec_queue_owner.has_items(): + return self._run_engine_prepare_ref_spec_once(batch_max_items) + if self.prepare_text_queue_owner.has_items(): + return self._run_engine_prepare_text_once(batch_max_items) + if self.prepare_ref_spec_queue_owner.has_items(): + return self._run_engine_prepare_ref_spec_once(batch_max_items) + return False + + def run_engine_prepare_audio_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + + def run_engine_prepare_text_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + + def run_engine_prepare_ref_spec_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py index 3a675cbe..bb9cb3cb 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -151,7 +151,9 @@ class WorkerFinalizeExecutor: @staticmethod def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]: - refer_specs = [job.state.refer_spec] + refer_specs = [] + if job.state.refer_spec is not None: + refer_specs.append(job.state.refer_spec) refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or [])) return refer_specs diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py index 8b3db1fa..9fb7c8d9 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio import os import time -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage @@ -81,11 +81,60 @@ class WorkerPrepareExecutor: cpu_stages: List[PreparedCpuStage], ) -> List[tuple[T2SRequestState, float, float] | Exception]: try: - return list( - await asyncio.gather( - *[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages], - return_exceptions=True, - ) + return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages) + finally: + self._notify_state_change() + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[Dict[str, Any] | Exception]: + try: + return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages) + finally: + self._notify_state_change() + + async def prepare_gpu_text_phases_async( + self, + items: List[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> List[Dict[str, Any] | Exception]: + try: + return await self.coordinator.prepare_gpu_text_phases_async(items) + finally: + self._notify_state_change() + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self.coordinator.build_gpu_prepare_result_from_phases( + cpu_stage, + phase_one, + phase_two, + extra_profile=extra_profile, ) finally: self._notify_state_change() + + async def prepare_ref_spec_stages_async( + self, + phase_ones: List[Dict[str, Any]], + ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + try: + return await self.coordinator.prepare_ref_spec_stages_async(phase_ones) + finally: + self._notify_state_change() + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + try: + self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py index e498e9ea..2ac636fe 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -267,3 +267,42 @@ class WorkerSubmitLifecycleMixin: cpu_stages: List[PreparedCpuStage], ) -> List[tuple[T2SRequestState, float, float] | Exception]: return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[Dict[str, Any] | Exception]: + return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages) + + async def prepare_gpu_text_phases_async( + self, + items: List[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> List[Dict[str, Any] | Exception]: + return await self.prepare_executor.prepare_gpu_text_phases_async(items) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + return self.prepare_executor.build_gpu_prepare_result_from_phases( + cpu_stage, + phase_one, + phase_two, + extra_profile=extra_profile, + ) + + async def prepare_ref_spec_stages_async( + self, + phase_ones: List[Dict[str, Any]], + ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result) diff --git a/GPT_SoVITS/text/g2pw/cuda_api.py b/GPT_SoVITS/text/g2pw/cuda_api.py index e1a84748..881d6123 100644 --- a/GPT_SoVITS/text/g2pw/cuda_api.py +++ b/GPT_SoVITS/text/g2pw/cuda_api.py @@ -244,6 +244,16 @@ class G2PWRuntimeWrapper: ) self.batch_worker.start() + def _sync_runtime_env_overrides(self) -> None: + os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0" + os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0" + os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0" + os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit)) + os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit)) + os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0" + os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0" + os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32") + def _destroy_handle(self) -> None: if self.handle: self.lib.g2pw_runtime_destroy(self.handle) @@ -268,6 +278,7 @@ class G2PWRuntimeWrapper: return "" if not message else message.decode("utf-8", errors="replace") def _create_handle(self, batch_size: int, seq_len: int) -> None: + self._sync_runtime_env_overrides() new_handle = self.lib.g2pw_runtime_create( str(self.manifest_path).encode("utf-8"), str(self.weights_path).encode("utf-8"), @@ -518,6 +529,10 @@ class G2PWRuntimeWrapper: return { "shard_index": int(self.shard_index), "enabled": bool(self.batch_enabled), + "enable_cuda_graph": bool(self.enable_cuda_graph), + "enable_profiling": bool(self.enable_profiling), + "full_graph_cache_limit": int(self.full_graph_cache_limit), + "tail_graph_cache_limit": int(self.tail_graph_cache_limit), "window_ms": float(self.batch_window_s * 1000.0), "max_requests": int(self.batch_max_requests), "max_rows": int(self.batch_max_rows),