diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 81c1ca1e..0140eff3 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -945,6 +945,13 @@ class TTS: codes = self.vits_model.extract_latent(hubert_feature) return codes[0, 0].to(self.configs.device) + @torch.inference_mode() + def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor): + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, forward_ms + @torch.inference_mode() def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): cpu_prepare_start = time.perf_counter() @@ -954,9 +961,7 @@ class TTS: zero_wav_samples=int(self.configs.sampling_rate * 0.3), ) cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 - forward_start = time.perf_counter() - prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) return prompt_semantic, cpu_prepare_ms, forward_ms @torch.inference_mode() @@ -1011,10 +1016,17 @@ class TTS: raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) load_ms = (time.perf_counter() - load_start) * 1000.0 if self.prepare_ref_semantic_batch_worker is None: + prompt_semantic_cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0 with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() - prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( - self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k( + wav16k ) prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_spec_start = time.perf_counter() @@ -1025,6 +1037,10 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_scatter_ms": 0.0, @@ -1046,6 +1062,18 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), @@ -1073,6 +1101,10 @@ class TTS: prompt_semantic_profile = { "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": 0.0, + "prompt_semantic_batch_dispatch_delay_ms": 0.0, "prompt_semantic_cpu_prepare_ms": 0.0, "prompt_semantic_forward_ms": 0.0, "prompt_semantic_scatter_ms": 0.0, @@ -1116,6 +1148,18 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), @@ -1193,6 +1237,18 @@ class TTS: "audio_stage_inflight_peak": float(audio_stage_inflight_peak), "prompt_semantic_ms": float(prompt_semantic_ms), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index 64ca133f..d46352a7 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -51,6 +51,7 @@ class RefSemanticTask: raw_sr: int task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) + batch_popped_at: float = 0.0 done_event: threading.Event = field(default_factory=threading.Event) done_loop: asyncio.AbstractEventLoop | None = None done_future: asyncio.Future | None = None @@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker: "max_batch_samples": self.max_batch_samples, } - def _collect_batch(self) -> List[RefSemanticTask]: + def _collect_batch(self) -> tuple[List[RefSemanticTask], float]: with self.condition: while not self.pending_tasks: self.condition.wait() - batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] + first_task = self.pending_tasks.popleft() + first_task.batch_popped_at = time.perf_counter() + batch: List[RefSemanticTask] = [first_task] batch_samples = self._estimate_task_samples(batch[0]) deadline = time.perf_counter() + self.batch_window_s @@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker: next_samples = self._estimate_task_samples(next_task) if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: break - batch.append(self.pending_tasks.popleft()) + popped_task = self.pending_tasks.popleft() + popped_task.batch_popped_at = time.perf_counter() + batch.append(popped_task) batch_samples += next_samples self.active_batch_size = len(batch) @@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker: self.active_batch_peak = self.active_batch_size if self.active_batch_samples > self.active_batch_samples_peak: self.active_batch_samples_peak = self.active_batch_samples - return batch + return batch, time.perf_counter() def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: with self.condition: @@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker: return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) @torch.inference_mode() - def _run_batch(self, batch: List[RefSemanticTask]) -> None: + def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None: batch_started = time.perf_counter() prepared_start = time.perf_counter() prepared_wavs = [ @@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker: try: code_len = int(code_lengths[batch_index].item()) task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0) + batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0) + stage_limiter_wait_ms = float(limiter_stats["wait_ms"]) task.profile = { - "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "prompt_semantic_wait_ms": worker_queue_wait_ms + + batch_collect_wait_ms + + stage_limiter_wait_ms, + "prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms, + "prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms, + "prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms, + "prompt_semantic_batch_dispatch_delay_ms": max( + 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 + ), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), "prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_scatter_ms": 0.0, @@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker: def _run_loop(self) -> None: while True: - batch = self._collect_batch() + batch, batch_collected_at = self._collect_batch() try: - self._run_batch(batch) + self._run_batch(batch, batch_collected_at) except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index e993a1ef..43290af7 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -305,6 +305,18 @@ def build_request_state_from_parts( "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),