Enhance audio processing in TTS framework with resampling and profiling improvements

Add resampling capabilities using torchaudio to prepare reference audio at 16kHz, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms to optimize resource usage. Update batch processing methods to include timing metrics for profiling, enhancing the ability to monitor and improve performance in the TTS system. This update improves the maintainability and efficiency of audio preparation workflows.
This commit is contained in:
baicai-1145 2026-03-13 02:03:25 +08:00
parent 17cb2e5acf
commit bc1f3f32de
3 changed files with 97 additions and 13 deletions

View File

@ -945,6 +945,13 @@ class TTS:
codes = self.vits_model.extract_latent(hubert_feature)
return codes[0, 0].to(self.configs.device)
@torch.inference_mode()
def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor):
forward_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
return prompt_semantic, forward_ms
@torch.inference_mode()
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
cpu_prepare_start = time.perf_counter()
@ -954,9 +961,7 @@ class TTS:
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
)
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
forward_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
return prompt_semantic, cpu_prepare_ms, forward_ms
@torch.inference_mode()
@ -1011,10 +1016,17 @@ class TTS:
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0
if self.prepare_ref_semantic_batch_worker is None:
prompt_semantic_cpu_prepare_start = time.perf_counter()
wav16k = prepare_prompt_semantic_wav16k(
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
)
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter()
prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = (
self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
wav16k
)
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
ref_spec_start = time.perf_counter()
@ -1025,6 +1037,10 @@ class TTS:
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
"prompt_semantic_scatter_ms": 0.0,
@ -1046,6 +1062,18 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
@ -1073,6 +1101,10 @@ class TTS:
prompt_semantic_profile = {
"prompt_semantic_wait_ms": 0.0,
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": 0.0,
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": 0.0,
"prompt_semantic_scatter_ms": 0.0,
@ -1116,6 +1148,18 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
@ -1193,6 +1237,18 @@ class TTS:
"audio_stage_inflight_peak": float(audio_stage_inflight_peak),
"prompt_semantic_ms": float(prompt_semantic_ms),
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),

View File

@ -51,6 +51,7 @@ class RefSemanticTask:
raw_sr: int
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
batch_popped_at: float = 0.0
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker:
"max_batch_samples": self.max_batch_samples,
}
def _collect_batch(self) -> List[RefSemanticTask]:
def _collect_batch(self) -> tuple[List[RefSemanticTask], float]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
first_task = self.pending_tasks.popleft()
first_task.batch_popped_at = time.perf_counter()
batch: List[RefSemanticTask] = [first_task]
batch_samples = self._estimate_task_samples(batch[0])
deadline = time.perf_counter() + self.batch_window_s
@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker:
next_samples = self._estimate_task_samples(next_task)
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
break
batch.append(self.pending_tasks.popleft())
popped_task = self.pending_tasks.popleft()
popped_task.batch_popped_at = time.perf_counter()
batch.append(popped_task)
batch_samples += next_samples
self.active_batch_size = len(batch)
@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker:
self.active_batch_peak = self.active_batch_size
if self.active_batch_samples > self.active_batch_samples_peak:
self.active_batch_samples_peak = self.active_batch_samples
return batch
return batch, time.perf_counter()
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
with self.condition:
@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker:
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
@torch.inference_mode()
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None:
batch_started = time.perf_counter()
prepared_start = time.perf_counter()
prepared_wavs = [
@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker:
try:
code_len = int(code_lengths[batch_index].item())
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0)
batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0)
stage_limiter_wait_ms = float(limiter_stats["wait_ms"])
task.profile = {
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"prompt_semantic_wait_ms": worker_queue_wait_ms
+ batch_collect_wait_ms
+ stage_limiter_wait_ms,
"prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms,
"prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms,
"prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms,
"prompt_semantic_batch_dispatch_delay_ms": max(
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker:
def _run_loop(self) -> None:
while True:
batch = self._collect_batch()
batch, batch_collected_at = self._collect_batch()
try:
self._run_batch(batch)
self._run_batch(batch, batch_collected_at)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc

View File

@ -305,6 +305,18 @@ def build_request_state_from_parts(
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_worker_queue_wait_ms": float(
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
"prompt_semantic_batch_collect_wait_ms": float(
bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
),
"prompt_semantic_stage_limiter_wait_ms": float(
bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
),
"prompt_semantic_batch_dispatch_delay_ms": float(
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
),
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),