diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 0140eff3..78bf7178 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -454,6 +454,7 @@ class TTS:
         }
         self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
         self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
+        self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8")))
         self.prepare_bert_batch_worker = None
         self.prepare_ref_semantic_batch_worker = None
         self.prepare_text_cpu_worker = None
@@ -952,15 +953,36 @@ class TTS:
         forward_ms = (time.perf_counter() - forward_start) * 1000.0
         return prompt_semantic, forward_ms
 
+    @torch.inference_mode()
+    def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int):
+        limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None)
+        if limiter is None:
+            cpu_prepare_start = time.perf_counter()
+            wav16k = prepare_prompt_semantic_wav16k(
+                raw_audio=raw_audio,
+                raw_sr=raw_sr,
+                zero_wav_samples=int(self.configs.sampling_rate * 0.3),
+            )
+            cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
+            return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0}
+
+        with limiter.enter() as limiter_stats:
+            cpu_prepare_start = time.perf_counter()
+            wav16k = prepare_prompt_semantic_wav16k(
+                raw_audio=raw_audio,
+                raw_sr=raw_sr,
+                zero_wav_samples=int(self.configs.sampling_rate * 0.3),
+            )
+            cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
+            return wav16k, cpu_prepare_ms, {
+                "wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
+                "slots": float(limiter_stats.get("slots", 0.0)),
+                "peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)),
+            }
+
     @torch.inference_mode()
     def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
-        cpu_prepare_start = time.perf_counter()
-        wav16k = prepare_prompt_semantic_wav16k(
-            raw_audio=raw_audio,
-            raw_sr=raw_sr,
-            zero_wav_samples=int(self.configs.sampling_rate * 0.3),
-        )
-        cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
+        wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
         prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
         return prompt_semantic, cpu_prepare_ms, forward_ms
 
@@ -1016,13 +1038,9 @@ class TTS:
         raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
         load_ms = (time.perf_counter() - load_start) * 1000.0
         if self.prepare_ref_semantic_batch_worker is None:
-            prompt_semantic_cpu_prepare_start = time.perf_counter()
-            wav16k = prepare_prompt_semantic_wav16k(
-                raw_audio=raw_audio,
-                raw_sr=raw_sr,
-                zero_wav_samples=int(self.configs.sampling_rate * 0.3),
+            wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = (
+                self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
             )
-            prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
             with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
                 prompt_semantic_start = time.perf_counter()
                 prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
@@ -1037,6 +1055,11 @@ class TTS:
             audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
             prompt_semantic_profile = {
                 "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
+                "prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)),
+                "prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)),
+                "prompt_semantic_cpu_prepare_inflight_peak": float(
+                    prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0)
+                ),
                 "prompt_semantic_worker_queue_wait_ms": 0.0,
                 "prompt_semantic_batch_collect_wait_ms": 0.0,
                 "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
@@ -1062,6 +1085,15 @@ class TTS:
                 "audio_stage_inflight_peak": audio_stage_inflight_peak,
                 "prompt_semantic_ms": prompt_semantic_ms,
                 "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
+                "prompt_semantic_cpu_prepare_wait_ms": float(
+                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
+                ),
+                "prompt_semantic_cpu_prepare_slots": float(
+                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
+                ),
+                "prompt_semantic_cpu_prepare_inflight_peak": float(
+                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
+                ),
                 "prompt_semantic_worker_queue_wait_ms": float(
                     prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
                 ),
@@ -1101,6 +1133,9 @@ class TTS:
 
         prompt_semantic_profile = {
             "prompt_semantic_wait_ms": 0.0,
+            "prompt_semantic_cpu_prepare_wait_ms": 0.0,
+            "prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)),
+            "prompt_semantic_cpu_prepare_inflight_peak": 0.0,
             "prompt_semantic_worker_queue_wait_ms": 0.0,
             "prompt_semantic_batch_collect_wait_ms": 0.0,
             "prompt_semantic_stage_limiter_wait_ms": 0.0,
@@ -1148,6 +1183,15 @@ class TTS:
             "audio_stage_inflight_peak": audio_stage_inflight_peak,
             "prompt_semantic_ms": prompt_semantic_ms,
             "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
+            "prompt_semantic_cpu_prepare_wait_ms": float(
+                prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
+            ),
+            "prompt_semantic_cpu_prepare_slots": float(
+                prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
+            ),
+            "prompt_semantic_cpu_prepare_inflight_peak": float(
+                prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
+            ),
             "prompt_semantic_worker_queue_wait_ms": float(
                 prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
             ),
diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
index d46352a7..ff5591b2 100644
--- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
+++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
@@ -1,4 +1,5 @@
 import asyncio
+import os
 import threading
 import time
 import uuid
@@ -6,28 +7,48 @@
 from collections import deque
 from dataclasses import dataclass, field
 from typing import Deque, Dict, List, Tuple
 
-import librosa
-import numpy as np
 import torch
+import torchaudio
 
 REF_AUDIO_MIN_SAMPLES_16K = 48000
 REF_AUDIO_MAX_SAMPLES_16K = 160000
 
+_RESAMPLE_CACHE_LOCK = threading.Lock()
+_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
+
+
+def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
+    device_key = str(device)
+    key = (int(orig_sr), int(target_sr), device_key)
+    with _RESAMPLE_CACHE_LOCK:
+        transform = _RESAMPLE_CACHE.get(key)
+        if transform is None:
+            transform = torchaudio.transforms.Resample(orig_freq=int(orig_sr), new_freq=int(target_sr)).to(device_key)
+            _RESAMPLE_CACHE[key] = transform
+        return transform
+
 
 def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
+    resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
+    if resample_device not in {"cpu", "cuda"}:
+        resample_device = "cpu"
+    if resample_device == "cuda" and not torch.cuda.is_available():
+        resample_device = "cpu"
     wav_mono = raw_audio
     if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
         wav_mono = wav_mono.mean(0, keepdim=True)
-    wav16k = wav_mono.squeeze(0).cpu().numpy()
+    wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
     if raw_sr != 16000:
-        wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000)
+        wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
+    wav16k = wav16k.squeeze(0).contiguous()
     if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
         raise OSError("参考音频在3~10秒范围外,请更换!")
-    wav16k = np.ascontiguousarray(wav16k, dtype=np.float32)
     if zero_wav_samples > 0:
-        wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0)
-    return torch.from_numpy(wav16k)
+        wav16k = torch.cat(
+            [wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)],
+            dim=0,
+        )
+    return wav16k.contiguous()
 
 
 def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
index 43290af7..73e2a2c7 100644
--- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
+++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
@@ -305,6 +305,11 @@ def build_request_state_from_parts(
         "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
         "prompt_semantic_ms": prompt_semantic_ms,
         "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
+        "prompt_semantic_cpu_prepare_wait_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)),
+        "prompt_semantic_cpu_prepare_slots": float(bundle_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)),
+        "prompt_semantic_cpu_prepare_inflight_peak": float(
+            bundle_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
+        ),
        "prompt_semantic_worker_queue_wait_ms": float(
            bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
        ),
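
Reviewer note: the new GPTSOVITS_PREPARE_REF_CPU_SLOTS knob (default 8) gates the CPU resample/pad step separately from the GPU-stage GPTSOVITS_PREPARE_REF_SLOTS limiter. The patch does not show StageLimiter itself; it only relies on enter() yielding a stats mapping with "wait_ms", "slots", and "peak_inflight" keys, and on a `slots` attribute (read via getattr in the fallback profile). A minimal sketch of a limiter satisfying that contract, for reference only, not the repo's actual implementation:

import threading
import time
from contextlib import contextmanager


class StageLimiter:
    """Illustrative semaphore-style limiter matching the contract the diff relies on."""

    def __init__(self, slots: int):
        self.slots = max(1, int(slots))
        self._semaphore = threading.Semaphore(self.slots)
        self._lock = threading.Lock()
        self._inflight = 0
        self._peak_inflight = 0

    @contextmanager
    def enter(self):
        # Time spent blocked on the semaphore is reported as wait_ms.
        wait_start = time.perf_counter()
        self._semaphore.acquire()
        wait_ms = (time.perf_counter() - wait_start) * 1000.0
        with self._lock:
            self._inflight += 1
            self._peak_inflight = max(self._peak_inflight, self._inflight)
            peak = self._peak_inflight
        try:
            yield {"wait_ms": wait_ms, "slots": float(self.slots), "peak_inflight": float(peak)}
        finally:
            with self._lock:
                self._inflight -= 1
            self._semaphore.release()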
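
Reviewer note: swapping librosa.resample for a cached torchaudio.transforms.Resample changes the resampling kernel (recent librosa defaults to soxr; torchaudio uses windowed-sinc interpolation), so reference audio will be numerically close but not bit-identical to the old pipeline. A quick smoke check of the new path, assuming librosa is still installed in the dev environment (the patch drops it as a runtime dependency of this module):

import librosa
import torch

from GPT_SoVITS.TTS_infer_pack.prepare_ref_semantic_batch_worker import prepare_prompt_semantic_wav16k

# 5 s of noise at 32 kHz resamples to 80000 samples at 16 kHz, inside the
# enforced 48000-160000 sample window (3-10 s); (1, N) input takes the mono passthrough branch.
raw = torch.randn(1, 32000 * 5)
wav16k = prepare_prompt_semantic_wav16k(raw_audio=raw, raw_sr=32000, zero_wav_samples=0)

# Old pipeline for comparison; kernels differ, so expect a small nonzero error.
ref = torch.from_numpy(librosa.resample(raw.squeeze(0).numpy(), orig_sr=32000, target_sr=16000))
assert wav16k.shape == ref.shape == (80000,)
print("max |new - old|:", (wav16k - ref).abs().max().item())

One behavioral detail worth flagging: with GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE=cuda the function returns a CUDA tensor (there is no final .cpu()), so downstream consumers must accept either device.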