mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-11 04:08:11 +08:00
Enhance TTS audio processing with improved resampling and profiling metrics
Refactor the audio preparation workflow to utilize torchaudio for resampling, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms and update the PrepareRefSemanticBatchWorker to include detailed timing metrics for profiling. Additionally, implement a new CPU limiter for managing resource allocation during audio processing. These changes improve the efficiency and maintainability of the TTS system.
This commit is contained in:
parent
bc1f3f32de
commit
c94de2f2cb
@ -454,6 +454,7 @@ class TTS:
|
||||
}
|
||||
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
|
||||
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
|
||||
self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8")))
|
||||
self.prepare_bert_batch_worker = None
|
||||
self.prepare_ref_semantic_batch_worker = None
|
||||
self.prepare_text_cpu_worker = None
|
||||
@ -952,15 +953,36 @@ class TTS:
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
return prompt_semantic, forward_ms
|
||||
|
||||
@torch.inference_mode()
|
||||
def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None)
|
||||
if limiter is None:
|
||||
cpu_prepare_start = time.perf_counter()
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
||||
return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0}
|
||||
|
||||
with limiter.enter() as limiter_stats:
|
||||
cpu_prepare_start = time.perf_counter()
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
||||
return wav16k, cpu_prepare_ms, {
|
||||
"wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
|
||||
"slots": float(limiter_stats.get("slots", 0.0)),
|
||||
"peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)),
|
||||
}
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
cpu_prepare_start = time.perf_counter()
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
||||
wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
|
||||
prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
|
||||
return prompt_semantic, cpu_prepare_ms, forward_ms
|
||||
|
||||
@ -1016,13 +1038,9 @@ class TTS:
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
||||
load_ms = (time.perf_counter() - load_start) * 1000.0
|
||||
if self.prepare_ref_semantic_batch_worker is None:
|
||||
prompt_semantic_cpu_prepare_start = time.perf_counter()
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = (
|
||||
self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
|
||||
)
|
||||
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
|
||||
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
|
||||
prompt_semantic_start = time.perf_counter()
|
||||
prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
|
||||
@ -1037,6 +1055,11 @@ class TTS:
|
||||
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_inflight_peak": float(
|
||||
prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0)
|
||||
),
|
||||
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
|
||||
@ -1062,6 +1085,15 @@ class TTS:
|
||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_slots": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_inflight_peak": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
@ -1101,6 +1133,9 @@ class TTS:
|
||||
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": 0.0,
|
||||
"prompt_semantic_cpu_prepare_wait_ms": 0.0,
|
||||
"prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_inflight_peak": 0.0,
|
||||
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||
"prompt_semantic_stage_limiter_wait_ms": 0.0,
|
||||
@ -1148,6 +1183,15 @@ class TTS:
|
||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_slots": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_inflight_peak": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@ -6,28 +7,48 @@ from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
REF_AUDIO_MIN_SAMPLES_16K = 48000
|
||||
REF_AUDIO_MAX_SAMPLES_16K = 160000
|
||||
_RESAMPLE_CACHE_LOCK = threading.Lock()
|
||||
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
|
||||
|
||||
|
||||
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
|
||||
device_key = str(device)
|
||||
key = (int(orig_sr), int(target_sr), device_key)
|
||||
with _RESAMPLE_CACHE_LOCK:
|
||||
transform = _RESAMPLE_CACHE.get(key)
|
||||
if transform is None:
|
||||
transform = torchaudio.transforms.Resample(orig_freq=int(orig_sr), new_freq=int(target_sr)).to(device_key)
|
||||
_RESAMPLE_CACHE[key] = transform
|
||||
return transform
|
||||
|
||||
|
||||
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
|
||||
resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
|
||||
if resample_device not in {"cpu", "cuda"}:
|
||||
resample_device = "cpu"
|
||||
if resample_device == "cuda" and not torch.cuda.is_available():
|
||||
resample_device = "cpu"
|
||||
wav_mono = raw_audio
|
||||
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
|
||||
wav_mono = wav_mono.mean(0, keepdim=True)
|
||||
wav16k = wav_mono.squeeze(0).cpu().numpy()
|
||||
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
|
||||
if raw_sr != 16000:
|
||||
wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000)
|
||||
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
|
||||
wav16k = wav16k.squeeze(0).contiguous()
|
||||
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
|
||||
raise OSError("参考音频在3~10秒范围外,请更换!")
|
||||
wav16k = np.ascontiguousarray(wav16k, dtype=np.float32)
|
||||
if zero_wav_samples > 0:
|
||||
wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0)
|
||||
return torch.from_numpy(wav16k)
|
||||
wav16k = torch.cat(
|
||||
[wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)],
|
||||
dim=0,
|
||||
)
|
||||
return wav16k.contiguous()
|
||||
|
||||
|
||||
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
|
||||
|
||||
@ -305,6 +305,11 @@ def build_request_state_from_parts(
|
||||
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_wait_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_slots": float(bundle_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_inflight_peak": float(
|
||||
bundle_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user