Enhance audio processing in TTS framework with resampling and profiling improvements

Add resampling capabilities using torchaudio to prepare reference audio at 16 kHz, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms to optimize resource usage. Update batch processing methods to include timing metrics for profiling, enhancing the ability to monitor and improve performance in the TTS system. This update improves the maintainability and efficiency of audio preparation workflows.
This commit is contained in:
baicai-1145 2026-03-13 02:03:25 +08:00
parent 17cb2e5acf
commit bc1f3f32de
3 changed files with 97 additions and 13 deletions

View File

@ -945,6 +945,13 @@ class TTS:
codes = self.vits_model.extract_latent(hubert_feature) codes = self.vits_model.extract_latent(hubert_feature)
return codes[0, 0].to(self.configs.device) return codes[0, 0].to(self.configs.device)
# Profiling wrapper around _extract_prompt_semantic_from_prepared_wav16k:
# runs the extraction on `wav16k` (presumably the already-prepared 16 kHz
# waveform -- confirm against the caller) and reports how long the call took.
# Returns a (prompt_semantic, forward_ms) tuple, where forward_ms is the
# elapsed time of the extraction call in milliseconds.
@torch.inference_mode()
def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor):
forward_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
# perf_counter() reports seconds; scale the difference to milliseconds.
# NOTE(review): this is a host-side timer; if the extraction launches
# asynchronous device work, the figure may not include its completion -- confirm.
forward_ms = (time.perf_counter() - forward_start) * 1000.0
return prompt_semantic, forward_ms
@torch.inference_mode() @torch.inference_mode()
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
cpu_prepare_start = time.perf_counter() cpu_prepare_start = time.perf_counter()
@ -954,9 +961,7 @@ class TTS:
zero_wav_samples=int(self.configs.sampling_rate * 0.3), zero_wav_samples=int(self.configs.sampling_rate * 0.3),
) )
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
forward_start = time.perf_counter() prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
return prompt_semantic, cpu_prepare_ms, forward_ms return prompt_semantic, cpu_prepare_ms, forward_ms
@torch.inference_mode() @torch.inference_mode()
@ -1011,10 +1016,17 @@ class TTS:
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0 load_ms = (time.perf_counter() - load_start) * 1000.0
if self.prepare_ref_semantic_batch_worker is None: if self.prepare_ref_semantic_batch_worker is None:
prompt_semantic_cpu_prepare_start = time.perf_counter()
wav16k = prepare_prompt_semantic_wav16k(
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
)
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter() prompt_semantic_start = time.perf_counter()
prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) wav16k
) )
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
ref_spec_start = time.perf_counter() ref_spec_start = time.perf_counter()
@ -1025,6 +1037,10 @@ class TTS:
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = { prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
"prompt_semantic_scatter_ms": 0.0, "prompt_semantic_scatter_ms": 0.0,
@ -1046,6 +1062,18 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak, "audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float( "prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
), ),
@ -1073,6 +1101,10 @@ class TTS:
prompt_semantic_profile = { prompt_semantic_profile = {
"prompt_semantic_wait_ms": 0.0, "prompt_semantic_wait_ms": 0.0,
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": 0.0,
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": 0.0, "prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": 0.0, "prompt_semantic_forward_ms": 0.0,
"prompt_semantic_scatter_ms": 0.0, "prompt_semantic_scatter_ms": 0.0,
@ -1116,6 +1148,18 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak, "audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float( "prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
), ),
@ -1193,6 +1237,18 @@ class TTS:
"audio_stage_inflight_peak": float(audio_stage_inflight_peak), "audio_stage_inflight_peak": float(audio_stage_inflight_peak),
"prompt_semantic_ms": float(prompt_semantic_ms), "prompt_semantic_ms": float(prompt_semantic_ms),
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),

View File

@ -51,6 +51,7 @@ class RefSemanticTask:
raw_sr: int raw_sr: int
task_id: str = field(default_factory=lambda: uuid.uuid4().hex) task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter) created_at: float = field(default_factory=time.perf_counter)
batch_popped_at: float = 0.0
done_event: threading.Event = field(default_factory=threading.Event) done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None done_future: asyncio.Future | None = None
@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker:
"max_batch_samples": self.max_batch_samples, "max_batch_samples": self.max_batch_samples,
} }
def _collect_batch(self) -> List[RefSemanticTask]: def _collect_batch(self) -> tuple[List[RefSemanticTask], float]:
with self.condition: with self.condition:
while not self.pending_tasks: while not self.pending_tasks:
self.condition.wait() self.condition.wait()
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] first_task = self.pending_tasks.popleft()
first_task.batch_popped_at = time.perf_counter()
batch: List[RefSemanticTask] = [first_task]
batch_samples = self._estimate_task_samples(batch[0]) batch_samples = self._estimate_task_samples(batch[0])
deadline = time.perf_counter() + self.batch_window_s deadline = time.perf_counter() + self.batch_window_s
@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker:
next_samples = self._estimate_task_samples(next_task) next_samples = self._estimate_task_samples(next_task)
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
break break
batch.append(self.pending_tasks.popleft()) popped_task = self.pending_tasks.popleft()
popped_task.batch_popped_at = time.perf_counter()
batch.append(popped_task)
batch_samples += next_samples batch_samples += next_samples
self.active_batch_size = len(batch) self.active_batch_size = len(batch)
@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker:
self.active_batch_peak = self.active_batch_size self.active_batch_peak = self.active_batch_size
if self.active_batch_samples > self.active_batch_samples_peak: if self.active_batch_samples > self.active_batch_samples_peak:
self.active_batch_samples_peak = self.active_batch_samples self.active_batch_samples_peak = self.active_batch_samples
return batch return batch, time.perf_counter()
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
with self.condition: with self.condition:
@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker:
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
@torch.inference_mode() @torch.inference_mode()
def _run_batch(self, batch: List[RefSemanticTask]) -> None: def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None:
batch_started = time.perf_counter() batch_started = time.perf_counter()
prepared_start = time.perf_counter() prepared_start = time.perf_counter()
prepared_wavs = [ prepared_wavs = [
@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker:
try: try:
code_len = int(code_lengths[batch_index].item()) code_len = int(code_lengths[batch_index].item())
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0)
batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0)
stage_limiter_wait_ms = float(limiter_stats["wait_ms"])
task.profile = { task.profile = {
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), "prompt_semantic_wait_ms": worker_queue_wait_ms
+ batch_collect_wait_ms
+ stage_limiter_wait_ms,
"prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms,
"prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms,
"prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms,
"prompt_semantic_batch_dispatch_delay_ms": max(
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0, "prompt_semantic_scatter_ms": 0.0,
@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker:
def _run_loop(self) -> None: def _run_loop(self) -> None:
while True: while True:
batch = self._collect_batch() batch, batch_collected_at = self._collect_batch()
try: try:
self._run_batch(batch) self._run_batch(batch, batch_collected_at)
except Exception as exc: # noqa: PERF203 except Exception as exc: # noqa: PERF203
for task in batch: for task in batch:
task.error = exc task.error = exc

View File

@ -305,6 +305,18 @@ def build_request_state_from_parts(
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),