From bc1f3f32de89e1ec6e093b8e27a6c16842a67753 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Fri, 13 Mar 2026 02:03:25 +0800 Subject: [PATCH] Enhance audio processing in TTS framework with resampling and profiling improvements Add resampling capabilities using torchaudio to prepare reference audio at 16kHz, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms to optimize resource usage. Update batch processing methods to include timing metrics for profiling, enhancing the ability to monitor and improve performance in the TTS system. This update improves the maintainability and efficiency of audio preparation workflows. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 66 +++++++++++++++++-- .../prepare_ref_semantic_batch_worker.py | 32 ++++++--- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 12 ++++ 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 81c1ca1e..0140eff3 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -945,6 +945,13 @@ class TTS: codes = self.vits_model.extract_latent(hubert_feature) return codes[0, 0].to(self.configs.device) + @torch.inference_mode() + def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor): + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, forward_ms + @torch.inference_mode() def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): cpu_prepare_start = time.perf_counter() @@ -954,9 +961,7 @@ class TTS: zero_wav_samples=int(self.configs.sampling_rate * 0.3), ) cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 - forward_start = time.perf_counter() - prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) return prompt_semantic, cpu_prepare_ms, forward_ms @torch.inference_mode() @@ -1011,10 +1016,17 @@ class TTS: raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) load_ms = (time.perf_counter() - load_start) * 1000.0 if self.prepare_ref_semantic_batch_worker is None: + prompt_semantic_cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0 with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() - prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( - self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k( + wav16k ) prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_spec_start = time.perf_counter() @@ -1025,6 +1037,10 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_scatter_ms": 0.0, @@ -1046,6 +1062,18 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), @@ -1073,6 +1101,10 @@ class TTS: prompt_semantic_profile = { "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": 0.0, + "prompt_semantic_batch_dispatch_delay_ms": 0.0, "prompt_semantic_cpu_prepare_ms": 0.0, "prompt_semantic_forward_ms": 0.0, "prompt_semantic_scatter_ms": 0.0, @@ -1116,6 +1148,18 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), @@ -1193,6 +1237,18 @@ class TTS: "audio_stage_inflight_peak": float(audio_stage_inflight_peak), "prompt_semantic_ms": float(prompt_semantic_ms), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index 64ca133f..d46352a7 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -51,6 +51,7 @@ class RefSemanticTask: raw_sr: int task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) + batch_popped_at: float = 0.0 done_event: threading.Event = field(default_factory=threading.Event) done_loop: asyncio.AbstractEventLoop | None = None done_future: asyncio.Future | None = None @@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker: "max_batch_samples": self.max_batch_samples, } - def _collect_batch(self) -> List[RefSemanticTask]: + def _collect_batch(self) -> tuple[List[RefSemanticTask], float]: with self.condition: while not self.pending_tasks: self.condition.wait() - batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] + first_task = self.pending_tasks.popleft() + first_task.batch_popped_at = time.perf_counter() + batch: List[RefSemanticTask] = [first_task] batch_samples = self._estimate_task_samples(batch[0]) deadline = time.perf_counter() + self.batch_window_s @@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker: next_samples = self._estimate_task_samples(next_task) if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: break - batch.append(self.pending_tasks.popleft()) + popped_task = self.pending_tasks.popleft() + popped_task.batch_popped_at = time.perf_counter() + batch.append(popped_task) batch_samples += next_samples self.active_batch_size = len(batch) @@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker: self.active_batch_peak = self.active_batch_size if self.active_batch_samples > self.active_batch_samples_peak: self.active_batch_samples_peak = self.active_batch_samples - return batch + return batch, time.perf_counter() def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: with self.condition: @@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker: return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) @torch.inference_mode() - def _run_batch(self, batch: List[RefSemanticTask]) -> None: + def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None: batch_started = time.perf_counter() prepared_start = time.perf_counter() prepared_wavs = [ @@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker: try: code_len = int(code_lengths[batch_index].item()) task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0) + batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0) + stage_limiter_wait_ms = float(limiter_stats["wait_ms"]) task.profile = { - "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "prompt_semantic_wait_ms": worker_queue_wait_ms + + batch_collect_wait_ms + + stage_limiter_wait_ms, + "prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms, + "prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms, + "prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms, + "prompt_semantic_batch_dispatch_delay_ms": max( + 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 + ), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), "prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_scatter_ms": 0.0, @@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker: def _run_loop(self) -> None: while True: - batch = self._collect_batch() + batch, batch_collected_at = self._collect_batch() try: - self._run_batch(batch) + self._run_batch(batch, batch_collected_at) except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index e993a1ef..43290af7 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -305,6 +305,18 @@ def build_request_state_from_parts( "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),