Enhance TTS audio processing with improved resampling and profiling metrics

Refactor the audio preparation workflow to utilize torchaudio for resampling, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms and update the PrepareRefSemanticBatchWorker to include detailed timing metrics for profiling. Additionally, implement a new CPU limiter for managing resource allocation during audio processing. These changes improve the efficiency and maintainability of the TTS system.
This commit is contained in:
baicai-1145 2026-03-13 16:45:00 +08:00
parent bc1f3f32de
commit c94de2f2cb
3 changed files with 90 additions and 20 deletions

View File

@ -454,6 +454,7 @@ class TTS:
} }
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8")))
self.prepare_bert_batch_worker = None self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None self.prepare_ref_semantic_batch_worker = None
self.prepare_text_cpu_worker = None self.prepare_text_cpu_worker = None
@ -952,15 +953,36 @@ class TTS:
forward_ms = (time.perf_counter() - forward_start) * 1000.0 forward_ms = (time.perf_counter() - forward_start) * 1000.0
return prompt_semantic, forward_ms return prompt_semantic, forward_ms
@torch.inference_mode()
def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int):
    """Prepare the 16 kHz reference waveform on CPU, timing the work and the limiter wait.

    Returns a ``(wav16k, cpu_prepare_ms, limiter_stats)`` tuple.  ``limiter_stats``
    carries ``wait_ms`` / ``slots`` / ``peak_inflight`` from the CPU stage limiter,
    or all zeros when no limiter is configured on this instance.
    """

    def _timed_prepare():
        # Single definition of the CPU prepare + timing, shared by both branches
        # (previously duplicated verbatim in the limiter / no-limiter paths).
        start = time.perf_counter()
        wav16k = prepare_prompt_semantic_wav16k(
            raw_audio=raw_audio,
            raw_sr=raw_sr,
            # 0.3 s of trailing silence at the model sampling rate.
            zero_wav_samples=int(self.configs.sampling_rate * 0.3),
        )
        return wav16k, (time.perf_counter() - start) * 1000.0

    # getattr(): tolerate instances created before the CPU limiter attribute existed.
    limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None)
    if limiter is None:
        wav16k, cpu_prepare_ms = _timed_prepare()
        return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0}
    with limiter.enter() as limiter_stats:
        wav16k, cpu_prepare_ms = _timed_prepare()
        return wav16k, cpu_prepare_ms, {
            "wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
            "slots": float(limiter_stats.get("slots", 0.0)),
            "peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)),
        }
@torch.inference_mode()
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
    """Run the full profiled prompt-semantic pipeline from raw reference audio.

    First prepares the 16 kHz waveform on CPU (timed), then runs the semantic
    forward pass (timed).  Returns ``(prompt_semantic, cpu_prepare_ms, forward_ms)``.
    The CPU-limiter stats from the preparation step are discarded here (``_``).
    """
    wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
    prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
    return prompt_semantic, cpu_prepare_ms, forward_ms
@ -1016,13 +1038,9 @@ class TTS:
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0 load_ms = (time.perf_counter() - load_start) * 1000.0
if self.prepare_ref_semantic_batch_worker is None: if self.prepare_ref_semantic_batch_worker is None:
prompt_semantic_cpu_prepare_start = time.perf_counter() wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = (
wav16k = prepare_prompt_semantic_wav16k( self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
) )
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter() prompt_semantic_start = time.perf_counter()
prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k( prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
@ -1037,6 +1055,11 @@ class TTS:
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = { prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": float(
prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": 0.0, "prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0, "prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]), "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
@ -1062,6 +1085,15 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak, "audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
),
"prompt_semantic_cpu_prepare_slots": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
),
"prompt_semantic_cpu_prepare_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": float( "prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
), ),
@ -1101,6 +1133,9 @@ class TTS:
prompt_semantic_profile = { prompt_semantic_profile = {
"prompt_semantic_wait_ms": 0.0, "prompt_semantic_wait_ms": 0.0,
"prompt_semantic_cpu_prepare_wait_ms": 0.0,
"prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": 0.0,
"prompt_semantic_worker_queue_wait_ms": 0.0, "prompt_semantic_worker_queue_wait_ms": 0.0,
"prompt_semantic_batch_collect_wait_ms": 0.0, "prompt_semantic_batch_collect_wait_ms": 0.0,
"prompt_semantic_stage_limiter_wait_ms": 0.0, "prompt_semantic_stage_limiter_wait_ms": 0.0,
@ -1148,6 +1183,15 @@ class TTS:
"audio_stage_inflight_peak": audio_stage_inflight_peak, "audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
),
"prompt_semantic_cpu_prepare_slots": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
),
"prompt_semantic_cpu_prepare_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": float( "prompt_semantic_worker_queue_wait_ms": float(
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
), ),

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import os
import threading import threading
import time import time
import uuid import uuid
@ -6,28 +7,48 @@ from collections import deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple from typing import Deque, Dict, List, Tuple
import librosa
import numpy as np
import torch import torch
import torchaudio
REF_AUDIO_MIN_SAMPLES_16K = 48000 REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000 REF_AUDIO_MAX_SAMPLES_16K = 160000
_RESAMPLE_CACHE_LOCK = threading.Lock()
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
    """Return a process-wide cached Resample transform for ``orig_sr -> target_sr`` on *device*.

    Transforms are built lazily on first use; the lock makes the cache safe to
    share across worker threads.
    """
    cache_key = (int(orig_sr), int(target_sr), str(device))
    with _RESAMPLE_CACHE_LOCK:
        cached = _RESAMPLE_CACHE.get(cache_key)
        if cached is not None:
            return cached
        resampler = torchaudio.transforms.Resample(
            orig_freq=cache_key[0], new_freq=cache_key[1]
        ).to(cache_key[2])
        _RESAMPLE_CACHE[cache_key] = resampler
        return resampler
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
    """Convert raw reference audio into the mono 16 kHz float32 tensor the semantic model expects.

    Downmixes multi-channel input, resamples to 16 kHz via a cached torchaudio
    transform, validates the 3-10 s duration window, and appends trailing silence.

    Raises:
        OSError: if the resampled audio falls outside the allowed sample range.
    """
    # Resample device is operator-configurable; anything unrecognized, or "cuda"
    # without an available GPU, falls back to CPU.
    resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
    if resample_device not in {"cpu", "cuda"}:
        resample_device = "cpu"
    if resample_device == "cuda" and not torch.cuda.is_available():
        resample_device = "cpu"
    wav_mono = raw_audio
    # Downmix (channels, samples) input to mono by averaging channels.
    if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
        wav_mono = wav_mono.mean(0, keepdim=True)
    wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
    if raw_sr != 16000:
        # Cached transform avoids rebuilding the resampling kernel per call.
        wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
    # squeeze(0) drops a leading singleton channel dim; no-op for 1-D input.
    wav16k = wav16k.squeeze(0).contiguous()
    # Enforce the 3-10 s reference-audio window (bounds are sample counts at 16 kHz).
    if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
        raise OSError("参考音频在3~10秒范围外请更换")
    if zero_wav_samples > 0:
        # Append trailing silence on the same device as the audio.
        wav16k = torch.cat(
            [wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)],
            dim=0,
        )
    return wav16k.contiguous()
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:

View File

@ -305,6 +305,11 @@ def build_request_state_from_parts(
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_wait_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_slots": float(bundle_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)),
"prompt_semantic_cpu_prepare_inflight_peak": float(
bundle_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
),
"prompt_semantic_worker_queue_wait_ms": float( "prompt_semantic_worker_queue_wait_ms": float(
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
), ),