diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 78bf7178..16bc8db8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -996,21 +996,39 @@ class TTS: return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return spec, audio, raw_audio, raw_sr + + def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + profile = { + "ref_spec_to_device_ms": 0.0, + "ref_spec_main_resample_ms": 0.0, + "ref_spec_norm_ms": 0.0, + "ref_spec_spectrogram_ms": 0.0, + "ref_spec_post_resample_ms": 0.0, + } + to_device_start = time.perf_counter() raw_audio_device = raw_audio.to(self.configs.device).float() + profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0 if raw_sr != self.configs.sampling_rate: + resample_start = time.perf_counter() audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) + profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0 else: audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) + norm_start = time.perf_counter() maxx = audio.abs().max() if maxx > 1: audio /= min(2, maxx) + profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0 + spec_start = time.perf_counter() spec = spectrogram_torch( audio, self.configs.filter_length, @@ -1019,15 +1037,18 @@ class TTS: self.configs.win_length, center=False, ) + profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0 if self.configs.is_half: spec = spec.half() if self.is_v2pro == True: + post_resample_start = time.perf_counter() audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) + profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0 if self.configs.is_half: audio = audio.half() else: audio = None - return spec, audio, raw_audio, raw_sr + return spec, audio, raw_audio, raw_sr, profile def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 65bcbb51..e74e0de4 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -228,8 +228,74 @@ class PrepareCoordinator: def _load_ref_audio_raw(self, ref_audio_path: str): return self.tts._load_ref_audio_raw(ref_audio_path) + def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int): + load_profile = {"audio_load_ms": 0.0} + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + **load_profile, + "audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)), + "audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_ms": float( + worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + + 
worker_profile.get("prompt_semantic_forward_ms", 0.0) + + worker_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + **{key: float(value) for key, value in worker_profile.items()}, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)), + }, + } + wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) + with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats: + prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": 0.0, + "audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "audio_stage_slots": float(stage_stats.get("slots", 0.0)), + "audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": 0.0, + "prompt_semantic_h2d_ms": 0.0, + "prompt_semantic_ssl_forward_ms": 0.0, + "prompt_semantic_hidden_length_ms": 0.0, + "prompt_semantic_extract_latent_ms": 0.0, + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)), + }, + } + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): - return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return (spec, audio), profile @staticmethod def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: @@ -523,7 +589,7 @@ class PrepareCoordinator: finally: self.text_feature_gate.release() - async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: + async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult: if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: submit_at = time.perf_counter() started_at = float(submit_at) @@ -538,19 +604,7 @@ class PrepareCoordinator: prompt_semantic_task = asyncio.create_task( self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) ) - await self.ref_spec_gate.acquire() - try: - ref_spec_task = asyncio.create_task( - self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) - ) - 
(prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather( - prompt_semantic_task, - ref_spec_task, - ) - finally: - self.ref_spec_gate.release() - - refer_spec = ref_spec_profiled.result + prompt_semantic, prompt_semantic_profile = await prompt_semantic_task limiter_snapshot = ( self.tts.prepare_ref_audio_stage_limiter.snapshot() if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None @@ -561,21 +615,15 @@ class PrepareCoordinator: + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) ) - audio_stage_wait_ms = ( - float(load_profiled.queue_ms) - + float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) - + float(ref_spec_profiled.queue_ms) - ) finished_at = time.perf_counter() result = { "prompt_semantic": prompt_semantic, - "refer_spec": refer_spec, "raw_audio": raw_audio, "raw_sr": raw_sr, "profile": { "audio_load_queue_ms": float(load_profiled.queue_ms), "audio_load_ms": float(load_profiled.run_ms), - "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "audio_stage_slots": float( max( float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -593,6 +641,17 @@ class PrepareCoordinator: "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), + "prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0) + ), + "prompt_semantic_hidden_length_ms": float( + prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0) + ), + "prompt_semantic_extract_latent_ms": float( + prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0) + ), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -603,14 +662,10 @@ class PrepareCoordinator: "prompt_semantic_batch_samples": float( prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) ), - "ref_spec_wait_ms": float(ref_spec_profiled.queue_ms), - "ref_spec_ms": float(ref_spec_profiled.run_ms), "bundle_total_ms": float( load_profiled.queue_ms + load_profiled.run_ms + prompt_semantic_ms - + ref_spec_profiled.queue_ms - + ref_spec_profiled.run_ms ), }, } @@ -623,21 +678,26 @@ class PrepareCoordinator: await self.ref_audio_gate.acquire() try: - if hasattr(self.tts, "extract_ref_audio_bundle_async"): - submit_at = time.perf_counter() - started_at = time.perf_counter() - result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) - finished_at = time.perf_counter() - return ProfiledResult( - result=result, - submit_at=float(submit_at), - started_at=float(started_at), - finished_at=float(finished_at), - ) - return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + raw_audio, raw_sr = load_profiled.result + submit_at = time.perf_counter() + started_at = 
time.perf_counter() + result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr) + result.setdefault("profile", {}) + result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms) + result["profile"]["audio_load_ms"] = float(load_profiled.run_ms) + finished_at = time.perf_counter() + return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at)) finally: self.ref_audio_gate.release() + async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult: + await self.ref_spec_gate.acquire() + try: + return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + finally: + self.ref_spec_gate.release() + def _release_split_stage_slot(self) -> None: self._mark_leave() self._inflight_gate.release() @@ -682,101 +742,318 @@ class PrepareCoordinator: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: try: - g2pw_pair_start = time.perf_counter() - g2pw_pair_task = asyncio.create_task( - self._run_g2pw_pair_stage( - cpu_stage.prompt_cpu_profiled.result, - cpu_stage.target_cpu_profiled.result, - ) - ) - ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) - (prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather( - g2pw_pair_task, - ref_audio_task, - ) - g2pw_pair_end = time.perf_counter() - text_pair_start = time.perf_counter() - text_feature_pair_task = asyncio.create_task( - self._run_text_feature_pair_stage( - prompt_g2pw_profiled.result, - target_g2pw_profiled.result, - cpu_stage.prompt_cpu_profiled.run_ms, - cpu_stage.target_cpu_profiled.run_ms, - prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), - target_base_profile=dict(target_g2pw_profiled.profile or {}), - ) - ) - prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task - text_pair_end = time.perf_counter() - state = build_request_state_from_parts( - tts=self.tts, - spec=cpu_stage.spec, - prompt_text=cpu_stage.prompt_text, - text=cpu_stage.text, - prompt_result=prompt_feature_profiled.result, - target_result=target_feature_profiled.result, - ref_audio_bundle=ref_audio_profiled.result, - prepare_start=cpu_stage.prepare_start, - prepare_sync_start=cpu_stage.prepare_start, - profile_overrides={ - "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), - "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, - "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), - "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), - "g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0), - "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, - "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, - "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), - "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), - "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), - "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, - "text_g2pw_run_ms": target_g2pw_profiled.run_ms, - "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), - "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), - 
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), - "prompt_text_parallel_future_wait_ms": 0.0, - "prompt_text_parallel_future_executor_queue_ms": 0.0, - "prompt_text_parallel_future_run_ms": 0.0, - "prompt_text_parallel_future_finish_after_submit_ms": 0.0, - "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, - "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, - "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, - "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, - "prompt_text_cpu_admission_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) - ), - "prompt_text_cpu_backpressure_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) - ), - "prompt_text_cpu_capacity_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) - ), - "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, - "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, - "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, - "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, - "text_cpu_admission_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) - ), - "text_cpu_backpressure_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) - ), - "text_cpu_capacity_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) - ), - "text_feature_queue_ms": target_feature_profiled.queue_ms, - "text_feature_run_ms": target_feature_profiled.run_ms, - "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, - "ref_audio_task_run_ms": ref_audio_profiled.run_ms, - "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), - "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + phase_one = await self._prepare_gpu_phase_one(cpu_stage) + phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one) + return self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 0.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]), + "engine_prepare_audio_phase_batch_size": 1.0, + "engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]), + "engine_prepare_text_phase_batch_size": 1.0, }, ) - prepare_exec_finished_at = time.perf_counter() - state.prepare_profile["executor_run_wall_ms"] = max( - 0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0 + finally: + self._release_split_stage_slot() + + async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]: + phase_start = time.perf_counter() + g2pw_pair_task = asyncio.create_task( + self._run_g2pw_pair_stage( + cpu_stage.prompt_cpu_profiled.result, + cpu_stage.target_cpu_profiled.result, ) - return state, cpu_stage.prepare_start, prepare_exec_finished_at + ) + ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path))) + prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task + g2pw_pair_end = time.perf_counter() + ref_audio_profiled = await ref_audio_task + phase_end = time.perf_counter() + return { + "prompt_g2pw_profiled": prompt_g2pw_profiled, + "target_g2pw_profiled": target_g2pw_profiled, + "ref_audio_profiled": ref_audio_profiled, + "ref_spec_result": 
None, + "g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0), + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + async def _prepare_gpu_phase_two( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + ) -> Dict[str, Any]: + phase_start = time.perf_counter() + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage( + prompt_g2pw_profiled.result, + target_g2pw_profiled.result, + cpu_stage.prompt_cpu_profiled.run_ms, + cpu_stage.target_cpu_profiled.run_ms, + prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), + target_base_profile=dict(target_g2pw_profiled.profile or {}), + ) + phase_end = time.perf_counter() + return { + "prompt_feature_profiled": prompt_feature_profiled, + "target_feature_profiled": target_feature_profiled, + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + def _build_gpu_prepare_result( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + ref_audio_profiled = phase_one["ref_audio_profiled"] + ref_spec_result = phase_one.get("ref_spec_result") + prompt_feature_profiled = phase_two["prompt_feature_profiled"] + target_feature_profiled = phase_two["target_feature_profiled"] + profile_overrides = { + "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, + "prepare_submit_ts": float(cpu_stage.prepare_submit_at), + "prepare_cpu_start_ts": float(cpu_stage.prepare_start), + "prepare_cpu_done_ts": float( + max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at) + ), + "prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at), + "prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at), + "text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at), + "text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at), + "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), + "text_feature_pair_ms": float(phase_two["phase_wall_ms"]), + "g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]), + "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, + "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, + "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, + "text_g2pw_run_ms": target_g2pw_profiled.run_ms, + "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + 
"prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, + "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), + "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + } + if extra_profile: + profile_overrides.update({key: float(value) for key, value in extra_profile.items()}) + ref_audio_bundle = dict(ref_audio_profiled.result) + ref_audio_profile = dict(ref_audio_bundle.get("profile", {})) + if ref_spec_result is not None: + refer_spec, ref_spec_profiled = ref_spec_result + ref_audio_bundle["refer_spec"] = refer_spec + ref_audio_profile.update( + { + "ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)), + } + ) + else: + ref_audio_bundle["refer_spec"] = None + ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0) + ref_audio_bundle["profile"] = ref_audio_profile + state = build_request_state_from_parts( + tts=self.tts, + spec=cpu_stage.spec, + prompt_text=cpu_stage.prompt_text, + text=cpu_stage.text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_bundle, + 
prepare_start=cpu_stage.prepare_start, + prepare_sync_start=cpu_stage.prepare_start, + profile_overrides=profile_overrides, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0) + return state, cpu_stage.prepare_start, prepare_exec_finished_at + + async def prepare_ref_spec_stages_async( + self, + phase_ones: list[Dict[str, Any]], + ) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + async def _one(phase_one: Dict[str, Any]): + ref_audio_profiled = phase_one["ref_audio_profiled"] + raw_audio = ref_audio_profiled.result["raw_audio"] + raw_sr = int(ref_audio_profiled.result["raw_sr"]) + profiled = await self._run_ref_spec_stage(raw_audio, raw_sr) + refer_spec, profile = profiled.result + merged_profile = dict(profile) + merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms) + merged_profile["ref_spec_ms"] = float(profiled.run_ms) + return refer_spec, merged_profile + + if not phase_ones: + return [] + return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True)) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + refer_spec, profile = ref_spec_result + state.refer_spec = refer_spec + state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0)) + state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0)) + state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0)) + state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0)) + state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0)) + state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0)) + state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0)) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[tuple[T2SRequestState, float, float] | Exception]: + if not cpu_stages: + return [] + if len(cpu_stages) == 1: + single_stage = cpu_stages[0] + try: + return [await self.prepare_gpu_stage_profiled_async(single_stage)] + except Exception as exc: # noqa: PERF203 + return [exc] + + phase_one_started_at = time.perf_counter() + phase_one_results = await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + phase_one_finished_at = time.perf_counter() + phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0) + + outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages) + pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = [] + for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)): + if isinstance(phase_one, Exception): + outputs[index] = phase_one + self._release_split_stage_slot() + continue + pending_phase_two.append((index, cpu_stage, phase_one)) + + phase_two_started_at = time.perf_counter() + phase_two_results = await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two], + return_exceptions=True, + ) + phase_two_finished_at = time.perf_counter() + phase_two_wall_ms = max(0.0, (phase_two_finished_at - 
phase_two_started_at) * 1000.0) + + for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results): + try: + if isinstance(phase_two, Exception): + outputs[index] = phase_two + continue + outputs[index] = self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 1.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms), + "engine_prepare_audio_phase_batch_size": float(len(cpu_stages)), + "engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms), + "engine_prepare_text_phase_batch_size": float(len(pending_phase_two)), + }, + ) + except Exception as exc: # noqa: PERF203 + outputs[index] = exc + finally: + self._release_split_stage_slot() + + return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs] + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[Dict[str, Any] | Exception]: + if not cpu_stages: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + + async def prepare_gpu_text_phases_async( + self, + items: list[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> list[Dict[str, Any] | Exception]: + if not items: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items], + return_exceptions=True, + ) + ) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile) finally: self._release_split_stage_slot() diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index ff5591b2..4628a2a2 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -15,6 +15,7 @@ REF_AUDIO_MIN_SAMPLES_16K = 48000 REF_AUDIO_MAX_SAMPLES_16K = 160000 _RESAMPLE_CACHE_LOCK = threading.Lock() _RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {} +_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {} def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample: @@ -28,6 +29,16 @@ def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.tran return transform +def _get_resample_stream(device: str) -> torch.cuda.Stream: + device_key = str(device) + with _RESAMPLE_CACHE_LOCK: + stream = _RESAMPLE_STREAM_CACHE.get(device_key) + if stream is None: + stream = torch.cuda.Stream(device=device_key) + _RESAMPLE_STREAM_CACHE[device_key] = stream + return stream + + def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu" if resample_device not in {"cpu", "cuda"}: @@ -37,10 +48,20 @@ def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wa wav_mono = raw_audio if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: wav_mono = wav_mono.mean(0, keepdim=True) - wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) - if raw_sr 
!= 16000: - wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) - wav16k = wav16k.squeeze(0).contiguous() + if resample_device == "cuda": + stream = _get_resample_stream(resample_device) + with torch.cuda.stream(stream): + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() + stream.synchronize() + wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous() + else: + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: raise OSError("参考音频在3~10秒范围外,请更换!") if zero_wav_samples > 0: @@ -256,37 +277,56 @@ class PrepareRefSemanticBatchWorker: batch_samples = int(wav_lengths.sum().item()) max_wav_len = int(wav_lengths.max().item()) + pack_start = time.perf_counter() input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) for batch_index, wav in enumerate(prepared_wavs): wav_len = int(wav.shape[0]) input_values_cpu[batch_index, :wav_len] = wav attention_mask_cpu[batch_index, :wav_len] = 1 + pack_ms = (time.perf_counter() - pack_start) * 1000.0 limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + h2d_ms = 0.0 + ssl_forward_ms = 0.0 + hidden_length_ms = 0.0 + extract_latent_ms = 0.0 if self.stage_limiter is None: + h2d_start = time.perf_counter() input_values = input_values_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device) if self.is_half: input_values = input_values.half() - forward_start = time.perf_counter() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() codes = self.vits_model.extract_latent(hubert_feature) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 else: with self.stage_limiter.enter() as limiter_stats: + h2d_start = time.perf_counter() input_values = input_values_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device) if self.is_half: input_values = input_values.half() - forward_start = time.perf_counter() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() codes = self.vits_model.extract_latent(hubert_feature) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + extract_latent_ms = 
(time.perf_counter() - latent_start) * 1000.0 + forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms) code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) scatter_start = time.perf_counter() @@ -308,6 +348,11 @@ class PrepareRefSemanticBatchWorker: 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 ), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": float(pack_ms), + "prompt_semantic_h2d_ms": float(h2d_ms), + "prompt_semantic_ssl_forward_ms": float(ssl_forward_ms), + "prompt_semantic_hidden_length_ms": float(hidden_length_ms), + "prompt_semantic_extract_latent_ms": float(extract_latent_ms), "prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_scatter_ms": 0.0, "prompt_semantic_calls": 1.0, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 73e2a2c7..a4d462d0 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -55,7 +55,7 @@ class T2SRequestState: all_phones: torch.LongTensor all_bert_features: torch.Tensor prompt_semantic: torch.LongTensor - refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] raw_audio: torch.Tensor raw_sr: int @@ -188,7 +188,11 @@ def build_request_state_from_parts( ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() - spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + refer_spec_value = ref_audio_bundle.get("refer_spec") + if refer_spec_value in [None, ()]: + spec_audio, audio_16k = None, None + else: + spec_audio, audio_16k = refer_spec_value aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [] for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []): if aux_ref_audio_path in [None, ""]: @@ -323,6 +327,11 @@ def build_request_state_from_parts( bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) ), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)), + "prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)), + "prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -331,6 +340,11 @@ def build_request_state_from_parts( "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_ms": ref_spec_ms, + "ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": 
float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)), "ref_audio_bundle_ms": ref_audio_bundle_ms, "tensorize_ms": tensorize_ms, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, @@ -352,7 +366,7 @@ def build_request_state_from_parts( all_phones=all_phones, all_bert_features=all_bert_features, prompt_semantic=prompt_semantic, - refer_spec=(spec_audio, audio_16k), + refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)), aux_refer_specs=aux_refer_specs, raw_audio=raw_audio, raw_sr=raw_sr, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py index 88b8cc5d..f07250e1 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py @@ -21,6 +21,14 @@ class EngineRegistryBridgeFacade: def engine_prepare_queue_owner(self): return self.owner.engine_prepare_queue_owner + @property + def engine_prepare_text_queue_owner(self): + return self.owner.engine_prepare_text_queue_owner + + @property + def engine_prepare_ref_spec_queue_owner(self): + return self.owner.engine_prepare_ref_spec_queue_owner + @property def engine_finalize_queue_owner(self): return self.owner.engine_finalize_queue_owner @@ -82,7 +90,33 @@ class EngineRegistryBridgeFacade: return self.request_registry.snapshot() def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16) + ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16) + return { + "waiting_count": int(audio_snapshot.get("waiting_count", 0)) + + int(text_snapshot.get("waiting_count", 0)) + + int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)), + "text_waiting_count": int(text_snapshot.get("waiting_count", 0)), + "ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])), + "text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])), + "ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])), + "peak_waiting": int( + max( + int(audio_snapshot.get("peak_waiting", 0)), + int(text_snapshot.get("peak_waiting", 0)), + int(ref_spec_snapshot.get("peak_waiting", 0)), + ) + ), + "total_submitted": int(audio_snapshot.get("total_submitted", 0)), + "total_completed": int(audio_snapshot.get("total_completed", 0)), + "text_total_submitted": int(text_snapshot.get("total_submitted", 0)), + "text_total_completed": int(text_snapshot.get("total_completed", 0)), + "ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)), + "ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)), + } def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) @@ -107,6 +141,8 @@ class EngineRegistryBridgeFacade: def _is_engine_drained(self) -> bool: prepare_empty = 
self.engine_prepare_queue_owner.is_drained() + prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained() + prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained() dispatch_empty = self.engine_dispatch_queue_owner.is_drained() finalize_empty = self.engine_finalize_queue_owner.is_drained() decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() @@ -114,6 +150,8 @@ class EngineRegistryBridgeFacade: worker_state = self.scheduler_worker.snapshot() return bool( prepare_empty + and prepare_text_empty + and prepare_ref_spec_empty and dispatch_empty and finalize_empty and decode_pending_empty diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py index 0e93c442..2cb1e175 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -110,6 +110,8 @@ class EngineCompositionBuilder: get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, ) owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") @@ -119,6 +121,8 @@ class EngineCompositionBuilder: tts=owner.tts, scheduler_worker=owner.scheduler_worker, prepare_queue_owner=owner.engine_prepare_queue_owner, + prepare_text_queue_owner=owner.engine_prepare_text_queue_owner, + prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner, finalize_queue_owner=owner.engine_finalize_queue_owner, dispatch_queue_owner=owner.engine_dispatch_queue_owner, decode_runtime_owner=owner.engine_decode_runtime_owner, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py index b6c5ca4d..65953dd9 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -129,7 +129,7 @@ class EnginePolicyArbiterController: self.state.total_ticks += 1 if stage == "idle": self.state.total_idle_ticks += 1 - elif stage == "prepare": + elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}: self.state.total_prepare_dispatches += 1 self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) elif stage == "finalize": @@ -282,7 +282,11 @@ class EnginePolicyArbiterController: request_registry = self.snapshot_request_registry() worker_state = self.get_worker_state() policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + prepare_state = self.snapshot_prepare_state() + prepare_waiting = int(prepare_state.get("waiting_count", 0)) + prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0)) + prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0)) + prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0)) finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) decode_runtime_state = self.snapshot_decode_runtime_state() @@ -291,6 +295,9 @@ 
class EnginePolicyArbiterController: worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio")) + prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text")) + prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec")) finalize_age_ms = float(self.peek_queue_age_ms("finalize")) decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) @@ -316,14 +323,31 @@ class EnginePolicyArbiterController: and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) ): return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if ( + finalize_waiting > 0 + and prepare_ref_spec_waiting > 0 + and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0) + ): + return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): return "finalize", "finalize_aging", policy_snapshot, worker_state if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0: + return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state + if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6): + return "prepare_text", "prepare_aging", policy_snapshot, worker_state + if ( + prepare_ref_spec_waiting > 0 + and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6) + ): + return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state + return "prepare_audio", "prepare_aging", policy_snapshot, worker_state if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state if decode_waiting > 0 and policy_allowed: @@ -331,5 +355,9 @@ class EnginePolicyArbiterController: if finalize_waiting > 0: return "finalize", "finalize_fallback", policy_snapshot, worker_state if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "prepare_fallback", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and 
prepare_audio_waiting <= 0: + return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state + return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index 15eedeca..600e1e83 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -324,8 +324,24 @@ class EngineGpuPrepareTask: done_future: asyncio.Future | None engine_request_id: str | None enqueue_time: float - queue_wait_ms: float = 0.0 + phase: str = "audio" + audio_enqueue_time: float = 0.0 + audio_start_time: float = 0.0 + audio_end_time: float = 0.0 + text_enqueue_time: float = 0.0 + text_start_time: float = 0.0 + text_end_time: float = 0.0 + ref_spec_enqueue_time: float = 0.0 + ref_spec_start_time: float = 0.0 + ref_spec_end_time: float = 0.0 + audio_queue_wait_ms: float = 0.0 + text_queue_wait_ms: float = 0.0 + ref_spec_queue_wait_ms: float = 0.0 admission_wait_ms: float = 0.0 + phase_one: Dict[str, Any] | None = None + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None + state_result: T2SRequestState | None = None + cancelled: bool = False error: str | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py index a71f7e4e..0c73616f 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -14,6 +14,8 @@ class EngineStageOrchestrator: executor: EngineStageExecutor, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -22,6 +24,8 @@ class EngineStageOrchestrator: self.executor = executor self.scheduler_worker = scheduler_worker self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner self.finalize_queue_owner = finalize_queue_owner self.dispatch_queue_owner = dispatch_queue_owner self.decode_runtime_owner = decode_runtime_owner @@ -45,7 +49,17 @@ class EngineStageOrchestrator: def peek_queue_age_ms(self, queue_name: str) -> float: if queue_name == "prepare": + return max( + self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"), + ) + if queue_name == "prepare_audio": return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_text": + return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_ref_spec": + return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time") if queue_name == "finalize": return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") if queue_name == "decode_runtime_pending": @@ -62,6 +76,10 @@ class EngineStageOrchestrator: return True if self.prepare_queue_owner.has_items(): return True + if self.prepare_text_queue_owner.has_items(): + return True + if 
self.prepare_ref_spec_queue_owner.has_items(): + return True if self.finalize_queue_owner.has_items(): return True return self.dispatch_queue_owner.has_items() @@ -79,6 +97,12 @@ class EngineStageOrchestrator: executed = False if stage == "prepare": executed = self.executor.run_engine_prepare_once() + elif stage == "prepare_audio": + executed = self.executor.run_engine_prepare_audio_once() + elif stage == "prepare_text": + executed = self.executor.run_engine_prepare_text_once() + elif stage == "prepare_ref_spec": + executed = self.executor.run_engine_prepare_ref_spec_once() elif stage == "finalize": executed = self.executor.run_engine_finalize_once() elif stage == "decode_dispatch": diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 1b872dfa..27ed3bf5 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -24,6 +24,8 @@ class EngineStageCoordinator: tts: TTS, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -45,6 +47,8 @@ class EngineStageCoordinator: tts=tts, scheduler_worker=scheduler_worker, prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, finalize_queue_owner=finalize_queue_owner, dispatch_queue_owner=dispatch_queue_owner, decode_runtime_owner=decode_runtime_owner, @@ -66,6 +70,8 @@ class EngineStageCoordinator: executor=self.executor, scheduler_worker=scheduler_worker, prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, finalize_queue_owner=finalize_queue_owner, dispatch_queue_owner=dispatch_queue_owner, decode_runtime_owner=decode_runtime_owner, @@ -144,6 +150,15 @@ class EngineStageCoordinator: def run_engine_prepare_once(self) -> bool: return self.executor.run_engine_prepare_once() + def run_engine_prepare_audio_once(self) -> bool: + return self.executor.run_engine_prepare_audio_once() + + def run_engine_prepare_text_once(self) -> bool: + return self.executor.run_engine_prepare_text_once() + + def run_engine_prepare_ref_spec_once(self) -> bool: + return self.executor.run_engine_prepare_ref_spec_once() + def run_engine_finalize_once(self) -> bool: return self.executor.run_engine_finalize_once() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py index 644c35f6..f6a249fa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py @@ -24,6 +24,10 @@ class EngineDispatchStageMixin: engine_request_id: str | None, timeout_sec: float | None, ) -> EngineDispatchTask: + if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + error = RuntimeError("ref_spec async stage failed before dispatch") + self.fail_request_state(engine_request_id or state.request_id, str(error)) + raise error task = EngineDispatchTask( request_id=state.request_id, state=state, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py index 
01921d51..6d06f0c6 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -31,6 +31,8 @@ class EngineStageExecutor( tts: TTS, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -51,6 +53,8 @@ class EngineStageExecutor( self.tts = tts self.scheduler_worker = scheduler_worker self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner self.finalize_queue_owner = finalize_queue_owner self.dispatch_queue_owner = dispatch_queue_owner self.decode_runtime_owner = decode_runtime_owner diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py index 8e66f76e..4b61993e 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py @@ -39,10 +39,37 @@ class EngineFinalizeStageMixin: tasks = self.take_engine_finalize_batch_nonblocking() if not tasks: return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) + ready_tasks: List[SchedulerFinalizeTask] = [] + failed_tasks: List[SchedulerFinalizeTask] = [] + deferred_tasks: List[SchedulerFinalizeTask] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + failed_tasks.append(task) + continue + if job.state.refer_spec is None: + deferred_tasks.append(task) + self.merge_request_state_profile( + job.engine_request_id or job.request_id, + { + "engine_finalize_ref_spec_blocked": 1.0, + }, + ) + continue + ready_tasks.append(task) + if deferred_tasks: + self.finalize_queue_owner.enqueue_many(deferred_tasks) + if failed_tasks: + self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed") + if not ready_tasks: + self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True) + return False + self.scheduler_worker.begin_finalize_execution(len(ready_tasks)) try: jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: + for task in ready_tasks: job = self.get_engine_job(task.request_id) if job is None: continue @@ -50,7 +77,7 @@ class EngineFinalizeStageMixin: if not jobs_and_items: return False now = time.perf_counter() - for task in tasks: + for task in ready_tasks: job = self.get_engine_job(task.request_id) if job is not None: job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) @@ -69,8 +96,8 @@ class EngineFinalizeStageMixin: for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc)) finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + self.scheduler_worker.end_finalize_execution(len(ready_tasks)) + 
+        if deferred_tasks:
+            self.finalize_queue_owner.enqueue_many(deferred_tasks)
+        if failed_tasks:
+            self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed")
+        if not ready_tasks:
+            self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True)
+            return False
+        self.scheduler_worker.begin_finalize_execution(len(ready_tasks))
         try:
             jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
-            for task in tasks:
+            for task in ready_tasks:
                 job = self.get_engine_job(task.request_id)
                 if job is None:
                     continue
@@ -50,7 +77,7 @@ class EngineFinalizeStageMixin:
             if not jobs_and_items:
                 return False
             now = time.perf_counter()
-            for task in tasks:
+            for task in ready_tasks:
                 job = self.get_engine_job(task.request_id)
                 if job is not None:
                     job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
@@ -69,8 +96,8 @@ class EngineFinalizeStageMixin:
             for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
                 self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
         except Exception as exc:
-            self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
+            self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc))
         finally:
-            self.scheduler_worker.end_finalize_execution(len(tasks))
-            self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
+            self.scheduler_worker.end_finalize_execution(len(ready_tasks))
+            self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True)
         return True
diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
index b9095d2c..1ea7c45b 100644
--- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
+++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
@@ -10,6 +10,13 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare
 
 
 class EnginePrepareStageMixin:
+    def _prepare_waiting_total(self) -> int:
+        return (
+            int(self.prepare_queue_owner.waiting_count())
+            + int(self.prepare_text_queue_owner.waiting_count())
+            + int(self.prepare_ref_spec_queue_owner.waiting_count())
+        )
+
     async def _wait_prepare_queue_admission(self) -> float:
         soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
         if soft_max <= 0:
@@ -19,7 +26,7 @@ class EnginePrepareStageMixin:
             float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
         )
         wait_start = time.perf_counter()
-        while self.prepare_queue_owner.waiting_count() >= soft_max:
+        while self._prepare_waiting_total() >= soft_max:
             await asyncio.sleep(poll_s)
         return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
@@ -53,44 +60,247 @@ class EnginePrepareStageMixin:
             done_future=done_future,
             engine_request_id=engine_request_id or spec.request_id,
             enqueue_time=time.perf_counter(),
+            phase="audio",
+            audio_enqueue_time=time.perf_counter(),
             admission_wait_ms=float(prepare_queue_admission_wait_ms),
         )
         self.prepare_queue_owner.enqueue(task)
         self.notify_arbiter()
         return await done_future
 
-    def run_engine_prepare_once(self) -> bool:
-        prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
-        tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
+    def _should_chain_prepare_text_after_audio(self) -> bool:
+        if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}:
+            return False
+        if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items():
+            return False
+        decode_runtime_state = self.snapshot_engine_decode_runtime_state()
+        if bool(decode_runtime_state.get("has_work", False)):
+            return False
+        return True
+
+    def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None:
+        if task.state_result is None or task.ref_spec_result is None:
+            return
+        self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result)
+        if task.engine_request_id not in [None, ""]:
+            self.merge_request_state_profile(
+                str(task.engine_request_id),
+                {
+                    "engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms),
+                    "ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)),
+                    "ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)),
+                    "ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)),
+                    "ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)),
+                    "ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)),
+                    "ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)),
+                    "ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)),
+                },
+            )
+
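+    # Marks both the task and the request-state profile with
+    # "ref_spec_async_failed" so the dispatch and finalize stages can reject
+    # the poisoned request instead of waiting on a spec that will never come.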
+    def _mark_ref_spec_async_failed(
+        self,
+        task: EngineGpuPrepareTask,
+        error: Exception,
+        *,
+        queue_wait_ms: float,
+    ) -> None:
+        task.error = str(error)
+        task.cancelled = True
+        if task.state_result is not None:
+            task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0
+            task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
+        if task.engine_request_id not in [None, ""]:
+            self.merge_request_state_profile(
+                str(task.engine_request_id),
+                {
+                    "ref_spec_async_failed": 1.0,
+                    "engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms),
+                },
+            )
+        self.fail_request_state(task.engine_request_id or task.request_id, str(error))
+        self.fail_engine_jobs([task.request_id], str(error))
+        self.notify_arbiter()
+
+    def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool:
+        tasks = self.prepare_queue_owner.pop_left_many(batch_max_items)
         if not tasks:
             return False
         now = time.perf_counter()
         queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
-        batch_results = asyncio.run(
-            self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks])
-        )
+        for task in tasks:
+            task.audio_start_time = float(now)
+        batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks]))
         completed_count = 0
         for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
+            task.audio_end_time = time.perf_counter()
             if isinstance(result, Exception):
                 task.error = str(result)
                 self.fail_request_state(task.engine_request_id or task.request_id, str(result))
                 self._notify_prepare_error(task, result)
                 completed_count += 1
                 continue
-            state, prepare_exec_started_at, prepare_exec_finished_at = result
-            state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms)
-            state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
+            task.audio_queue_wait_ms = float(queue_wait_ms)
+            task.phase_one = result
+            task.phase = "text"
+            task.enqueue_time = time.perf_counter()
+            task.text_enqueue_time = float(task.enqueue_time)
+            task.ref_spec_enqueue_time = float(task.enqueue_time)
+            self.prepare_text_queue_owner.enqueue(task)
+            self.prepare_ref_spec_queue_owner.enqueue(task)
+            if task.engine_request_id not in [None, ""]:
+                self.merge_request_state_profile(
+                    str(task.engine_request_id),
+                    {
+                        "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
+                        "engine_prepare_audio_queue_wait_ms": float(queue_wait_ms),
+                        "engine_prepare_audio_batch_size": float(len(tasks)),
+                        "engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
+                        "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
+                        "engine_prepare_audio_start_ts": float(task.audio_start_time),
+                        "engine_prepare_audio_end_ts": float(task.audio_end_time),
+                        "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
+                        "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
+                    },
+                )
+            completed_count += 1
+        self.prepare_queue_owner.mark_completed(completed_count)
+        if completed_count > 0 and self._should_chain_prepare_text_after_audio():
+            self._run_engine_prepare_text_once(min(batch_max_items, completed_count))
+            return True
+        if completed_count > 0:
+            self.notify_arbiter()
+        return True
+
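+    # Phase two: batch the text-side GPU work for tasks whose audio phase has
+    # finished, then assemble the final prepare result from both phases.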
+    def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool:
+        tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items)
+        if not tasks:
+            return False
+        now = time.perf_counter()
+        queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
+        for task in tasks:
+            task.text_start_time = float(now)
+        items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None]
+        batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items))
+        completed_count = 0
+        for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
+            task.text_end_time = time.perf_counter()
+            if isinstance(result, Exception):
+                task.error = str(result)
+                task.cancelled = True
+                self.fail_request_state(task.engine_request_id or task.request_id, str(result))
+                self._notify_prepare_error(task, result)
+                completed_count += 1
+                continue
+            task.text_queue_wait_ms = float(queue_wait_ms)
+            state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases(
+                task.cpu_stage,
+                task.phase_one or {},
+                result,
+                extra_profile={
+                    "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
+                    "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
+                    "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
+                    "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
+                    "engine_prepare_audio_batch_size": float(len(tasks)),
+                    "engine_prepare_text_batch_size": float(len(tasks)),
+                    "engine_prepare_audio_phase_mode": 2.0,
+                    "engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)),
+                    "engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
+                    "engine_prepare_text_phase_batch_size": float(len(tasks)),
+                    "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
+                    "engine_prepare_audio_start_ts": float(task.audio_start_time),
+                    "engine_prepare_audio_end_ts": float(task.audio_end_time),
+                    "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
+                    "engine_prepare_text_start_ts": float(task.text_start_time),
+                    "engine_prepare_text_end_ts": float(task.text_end_time),
+                    "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
+                },
+            )
+            task.state_result = state
+            self._maybe_apply_ref_spec_to_state(task)
             state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
             if task.engine_request_id not in [None, ""]:
                 self.merge_request_state_profile(
                     str(task.engine_request_id),
                     {
                         "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
-                        "engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms),
+                        "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
+                        "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
+                        "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
                         "engine_gpu_prepare_batch_size": float(len(tasks)),
                     },
                 )
             self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
             completed_count += 1
-        self.prepare_queue_owner.mark_completed(completed_count)
+        self.prepare_text_queue_owner.mark_completed(completed_count)
         return True
+
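+    # Phase three: compute reference spectrograms for queued tasks; cancelled
+    # tasks and tasks without a phase-one result are skipped, and results are
+    # attached to the request state once it exists.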
+    def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool:
+        tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items)
+        if not tasks:
+            return False
+        now = time.perf_counter()
+        runnable_tasks: list[EngineGpuPrepareTask] = []
+        queue_wait_ms_list: list[float] = []
+        completed_count = 0
+        for task in tasks:
+            if task.cancelled or task.phase_one is None:
+                completed_count += 1
+                continue
+            task.ref_spec_start_time = float(now)
+            runnable_tasks.append(task)
+            queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0))
+        if not runnable_tasks:
+            self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
+            return True
+        batch_results = asyncio.run(
+            self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks])
+        )
+        for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results):
+            task.ref_spec_end_time = time.perf_counter()
+            task.ref_spec_queue_wait_ms = float(queue_wait_ms)
+            if isinstance(result, Exception):
+                self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms))
+                completed_count += 1
+                continue
+            task.ref_spec_result = result
+            self._maybe_apply_ref_spec_to_state(task)
+            if task.state_result is not None:
+                task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
+                task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time)
+                task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time)
+                task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time)
+            completed_count += 1
+        self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
+        return True
+
+    def run_engine_prepare_once(self) -> bool:
+        prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
+        batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1))
+        audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
+        text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
+        if self.prepare_text_queue_owner.has_items() and (
+            not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms
+        ):
+            return self._run_engine_prepare_text_once(batch_max_items)
+        if self.prepare_queue_owner.has_items():
+            return self._run_engine_prepare_audio_once(batch_max_items)
+        if self.prepare_ref_spec_queue_owner.has_items():
+            return self._run_engine_prepare_ref_spec_once(batch_max_items)
+        if self.prepare_text_queue_owner.has_items():
+            return self._run_engine_prepare_text_once(batch_max_items)
+        if self.prepare_ref_spec_queue_owner.has_items():
+            return self._run_engine_prepare_ref_spec_once(batch_max_items)
+        return False
+
+    def run_engine_prepare_audio_once(self) -> bool:
+        prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
+        return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
+
+    def run_engine_prepare_text_once(self) -> bool:
+        prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
+        return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
+
+    def run_engine_prepare_ref_spec_once(self) -> bool:
+        prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
+        return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py
index 3a675cbe..bb9cb3cb 100644
--- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py
+++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py
@@ -151,7 +151,9 @@ class WorkerFinalizeExecutor:
 
     @staticmethod
     def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
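+        # With the ref_spec stage now asynchronous, refer_spec may legitimately
+        # be None here; collect it only once it has been attached.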
-        refer_specs = [job.state.refer_spec]
+        refer_specs = []
+        if job.state.refer_spec is not None:
+            refer_specs.append(job.state.refer_spec)
         refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
         return refer_specs
diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py
index 8b3db1fa..9fb7c8d9 100644
--- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py
+++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import os
 import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
 
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS
 from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
@@ -81,11 +81,60 @@ class WorkerPrepareExecutor:
         cpu_stages: List[PreparedCpuStage],
     ) -> List[tuple[T2SRequestState, float, float] | Exception]:
         try:
-            return list(
-                await asyncio.gather(
-                    *[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages],
-                    return_exceptions=True,
-                )
+            return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages)
+        finally:
+            self._notify_state_change()
+
+    async def prepare_gpu_audio_phases_async(
+        self,
+        cpu_stages: List[PreparedCpuStage],
+    ) -> List[Dict[str, Any] | Exception]:
+        try:
+            return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages)
+        finally:
+            self._notify_state_change()
+
+    async def prepare_gpu_text_phases_async(
+        self,
+        items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
+    ) -> List[Dict[str, Any] | Exception]:
+        try:
+            return await self.coordinator.prepare_gpu_text_phases_async(items)
+        finally:
+            self._notify_state_change()
+
+    def build_gpu_prepare_result_from_phases(
+        self,
+        cpu_stage: PreparedCpuStage,
+        phase_one: Dict[str, Any],
+        phase_two: Dict[str, Any],
+        extra_profile: Dict[str, float] | None = None,
+    ) -> tuple[T2SRequestState, float, float]:
+        try:
+            return self.coordinator.build_gpu_prepare_result_from_phases(
+                cpu_stage,
+                phase_one,
+                phase_two,
+                extra_profile=extra_profile,
             )
         finally:
             self._notify_state_change()
+
+    async def prepare_ref_spec_stages_async(
+        self,
+        phase_ones: List[Dict[str, Any]],
+    ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
+        try:
+            return await self.coordinator.prepare_ref_spec_stages_async(phase_ones)
+        finally:
+            self._notify_state_change()
+
+    def apply_ref_spec_result_to_state(
+        self,
+        state: T2SRequestState,
+        ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
+    ) -> None:
+        try:
+            self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result)
+        finally:
+            self._notify_state_change()
diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py
index e498e9ea..2ac636fe 100644
--- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py
+++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py
@@ -267,3 +267,42 @@ class WorkerSubmitLifecycleMixin:
         cpu_stages: List[PreparedCpuStage],
     ) -> List[tuple[T2SRequestState, float, float] | Exception]:
         return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
+
+    async def prepare_gpu_audio_phases_async(
+        self,
+        cpu_stages: List[PreparedCpuStage],
+    ) -> List[Dict[str, Any] | Exception]:
+        return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages)
+
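+    # The remaining phase APIs are thin passthroughs, presumably so the engine
+    # stages can reach the prepare executor via the scheduler worker.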
+    async def prepare_gpu_text_phases_async(
+        self,
+        items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
+    ) -> List[Dict[str, Any] | Exception]:
+        return await self.prepare_executor.prepare_gpu_text_phases_async(items)
+
+    def build_gpu_prepare_result_from_phases(
+        self,
+        cpu_stage: PreparedCpuStage,
+        phase_one: Dict[str, Any],
+        phase_two: Dict[str, Any],
+        extra_profile: Dict[str, float] | None = None,
+    ) -> tuple[T2SRequestState, float, float]:
+        return self.prepare_executor.build_gpu_prepare_result_from_phases(
+            cpu_stage,
+            phase_one,
+            phase_two,
+            extra_profile=extra_profile,
+        )
+
+    async def prepare_ref_spec_stages_async(
+        self,
+        phase_ones: List[Dict[str, Any]],
+    ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
+        return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones)
+
+    def apply_ref_spec_result_to_state(
+        self,
+        state: T2SRequestState,
+        ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
+    ) -> None:
+        self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result)
diff --git a/GPT_SoVITS/text/g2pw/cuda_api.py b/GPT_SoVITS/text/g2pw/cuda_api.py
index e1a84748..881d6123 100644
--- a/GPT_SoVITS/text/g2pw/cuda_api.py
+++ b/GPT_SoVITS/text/g2pw/cuda_api.py
@@ -244,6 +244,16 @@ class G2PWRuntimeWrapper:
         )
         self.batch_worker.start()
 
+    def _sync_runtime_env_overrides(self) -> None:
+        os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0"
+        os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0"
+        os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0"
+        os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit))
+        os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit))
+        os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0"
+        os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0"
+        os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32")
+
     def _destroy_handle(self) -> None:
         if self.handle:
             self.lib.g2pw_runtime_destroy(self.handle)
@@ -268,6 +278,7 @@ class G2PWRuntimeWrapper:
         return "" if not message else message.decode("utf-8", errors="replace")
 
     def _create_handle(self, batch_size: int, seq_len: int) -> None:
+        self._sync_runtime_env_overrides()
         new_handle = self.lib.g2pw_runtime_create(
             str(self.manifest_path).encode("utf-8"),
             str(self.weights_path).encode("utf-8"),
@@ -518,6 +529,10 @@ class G2PWRuntimeWrapper:
         return {
             "shard_index": int(self.shard_index),
             "enabled": bool(self.batch_enabled),
+            "enable_cuda_graph": bool(self.enable_cuda_graph),
+            "enable_profiling": bool(self.enable_profiling),
+            "full_graph_cache_limit": int(self.full_graph_cache_limit),
+            "tail_graph_cache_limit": int(self.tail_graph_cache_limit),
             "window_ms": float(self.batch_window_s * 1000.0),
             "max_requests": int(self.batch_max_requests),
             "max_rows": int(self.batch_max_rows),