Enhance TTS audio processing with improved resampling and profiling metrics

Refactor the audio preparation workflow to utilize torchaudio for resampling, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms and update the PrepareRefSemanticBatchWorker to include detailed timing metrics for profiling. Additionally, implement a new CPU limiter for managing resource allocation during audio processing. These changes improve the efficiency and maintainability of the TTS system.
This commit is contained in:
baicai-1145 2026-03-13 16:45:00 +08:00
parent bc1f3f32de
commit c94de2f2cb
3 changed files with 90 additions and 20 deletions

View File

@ -454,6 +454,7 @@ class TTS:
}
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8")))
self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None
self.prepare_text_cpu_worker = None
@ -952,15 +953,36 @@ class TTS:
forward_ms = (time.perf_counter() - forward_start) * 1000.0
return prompt_semantic, forward_ms
@torch.inference_mode()
def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int):
    """Run the CPU-side 16 kHz reference-audio preparation with timing.

    Returns a 3-tuple of (wav16k tensor, cpu_prepare_ms, limiter_stats) where
    limiter_stats has keys "wait_ms", "slots", "peak_inflight" (all floats;
    zeroed when no CPU limiter is configured).
    """

    def _timed_prepare():
        # Shared timed body: mono/resample/pad via the module-level helper,
        # padding with 0.3 s of silence at the model sampling rate.
        start = time.perf_counter()
        wav = prepare_prompt_semantic_wav16k(
            raw_audio=raw_audio,
            raw_sr=raw_sr,
            zero_wav_samples=int(self.configs.sampling_rate * 0.3),
        )
        return wav, (time.perf_counter() - start) * 1000.0

    limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None)
    if limiter is None:
        # Limiter attribute absent (e.g. instance built before this field
        # existed): run unguarded and report zeroed limiter stats.
        wav16k, cpu_prepare_ms = _timed_prepare()
        return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0}
    with limiter.enter() as limiter_stats:
        wav16k, cpu_prepare_ms = _timed_prepare()
        # Read stats inside the context, mirroring how the stage limiter is
        # consumed elsewhere in this class.
        return wav16k, cpu_prepare_ms, {
            "wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
            "slots": float(limiter_stats.get("slots", 0.0)),
            "peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)),
        }
@torch.inference_mode()
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
    """Extract the prompt semantic tokens from raw reference audio with timing.

    Returns (prompt_semantic, cpu_prepare_ms, forward_ms). The diff span
    previously retained the pre-refactor inline prepare code, whose result was
    immediately overwritten by the delegating call below — that dead
    computation is removed here.
    """
    # CPU-side prepare; limiter stats are discarded on this code path.
    wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
    prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
    return prompt_semantic, cpu_prepare_ms, forward_ms
@ -1016,13 +1038,9 @@ class TTS:
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0
if self.prepare_ref_semantic_batch_worker is None:
prompt_semantic_cpu_prepare_start = time.perf_counter()
wav16k = prepare_prompt_semantic_wav16k(
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = (
self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
)
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter()
prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
@ -1037,6 +1055,11 @@ class TTS:
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": float(
prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
@ -1062,6 +1085,15 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
),
"prompt_semantic_cpu_prepare_slots": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
),
"prompt_semantic_cpu_prepare_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),
@ -1101,6 +1133,9 @@ class TTS:
prompt_semantic_profile = {
"prompt_semantic_wait_ms": 0.0,
"prompt_semantic_cpu_prepare_wait_ms": 0.0,
"prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": 0.0,
"prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": 0.0,
@ -1148,6 +1183,15 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
),
"prompt_semantic_cpu_prepare_slots": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
),
"prompt_semantic_cpu_prepare_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),

View File

@ -1,4 +1,5 @@
import asyncio
import os
import threading
import time
import uuid
@ -6,28 +7,48 @@ from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import librosa
import numpy as np
import torch
import torchaudio
REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000
_RESAMPLE_CACHE_LOCK = threading.Lock()
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
    """Return a cached ``torchaudio.transforms.Resample`` for the given rates/device.

    Transforms are memoized per (orig_sr, target_sr, device) in the
    module-level cache so the polyphase kernels are built only once.
    """
    cache_key = (int(orig_sr), int(target_sr), str(device))
    with _RESAMPLE_CACHE_LOCK:
        cached = _RESAMPLE_CACHE.get(cache_key)
        if cached is not None:
            return cached
        # Construct under the lock so concurrent callers never build duplicates.
        resampler = torchaudio.transforms.Resample(
            orig_freq=cache_key[0], new_freq=cache_key[1]
        ).to(cache_key[2])
        _RESAMPLE_CACHE[cache_key] = resampler
        return resampler
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
    """Prepare reference audio for prompt-semantic extraction.

    Converts ``raw_audio`` to mono float32, resamples it to 16 kHz via the
    cached torchaudio transform, validates its length, and appends
    ``zero_wav_samples`` of trailing silence. The diff span interleaved the
    removed librosa/numpy implementation with the new torch one (shadowed
    assignments, a librosa call on a torch tensor, an unreachable tail); this
    body is the coherent torch-only version.

    Raises:
        OSError: if the 16 kHz audio falls outside the allowed sample bounds
            (the bounds correspond to roughly 3–10 seconds of audio).
    """
    # Resample device is opt-in via env var; fall back to CPU on anything
    # invalid or when CUDA is unavailable.
    resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
    if resample_device not in {"cpu", "cuda"}:
        resample_device = "cpu"
    if resample_device == "cuda" and not torch.cuda.is_available():
        resample_device = "cpu"
    wav_mono = raw_audio
    # Mix multi-channel (channels, samples) audio down to a single channel.
    if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
        wav_mono = wav_mono.mean(0, keepdim=True)
    wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
    if raw_sr != 16000:
        # torchaudio resamples along the last dimension, so this handles both
        # 1-D and (1, samples) inputs.
        wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
    wav16k = wav16k.squeeze(0).contiguous()
    if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
        raise OSError("参考音频在3~10秒范围外请更换")
    if zero_wav_samples > 0:
        # Trailing silence, allocated on the same device as the audio.
        wav16k = torch.cat(
            [wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)],
            dim=0,
        )
    return wav16k.contiguous()
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:

View File

@ -305,6 +305,11 @@ def build_request_state_from_parts(
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_slots": float(bundle_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": float(
bundle_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": float(
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
),