mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-24 04:48:16 +08:00
Enhance audio processing in TTS framework with resampling and profiling improvements
Add resampling with torchaudio to prepare reference audio at 16 kHz, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms to optimize resource usage. Update the batch processing methods to record timing metrics for profiling (worker queue wait, batch collection wait, stage limiter wait, and batch dispatch delay), making it easier to monitor and improve performance in the TTS system. Together, these changes improve the maintainability and efficiency of the audio preparation workflow.
This commit is contained in:
parent
17cb2e5acf
commit
bc1f3f32de
@ -945,6 +945,13 @@ class TTS:
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
return codes[0, 0].to(self.configs.device)
|
||||
|
||||
@torch.inference_mode()
def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor):
    """Run prompt-semantic extraction on ``wav16k`` and time the call.

    Thin profiling wrapper around
    ``self._extract_prompt_semantic_from_prepared_wav16k``.

    Args:
        wav16k: Audio tensor forwarded unchanged to the wrapped method
            (presumably already prepared at 16 kHz, judging by the name --
            confirm against the caller).

    Returns:
        A ``(prompt_semantic, forward_ms)`` tuple: the wrapped method's
        result and the elapsed host time of that call in milliseconds.
    """
    # Host-side wall clock only; this wrapper adds no synchronization of
    # its own around the wrapped call.
    started_at = time.perf_counter()
    semantic_codes = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
    elapsed_s = time.perf_counter() - started_at
    return semantic_codes, elapsed_s * 1000.0
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
cpu_prepare_start = time.perf_counter()
|
||||
@ -954,9 +961,7 @@ class TTS:
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
||||
forward_start = time.perf_counter()
|
||||
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
|
||||
return prompt_semantic, cpu_prepare_ms, forward_ms
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -1011,10 +1016,17 @@ class TTS:
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
||||
load_ms = (time.perf_counter() - load_start) * 1000.0
|
||||
if self.prepare_ref_semantic_batch_worker is None:
|
||||
prompt_semantic_cpu_prepare_start = time.perf_counter()
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
|
||||
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
|
||||
prompt_semantic_start = time.perf_counter()
|
||||
prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = (
|
||||
self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
|
||||
prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
|
||||
wav16k
|
||||
)
|
||||
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
|
||||
ref_spec_start = time.perf_counter()
|
||||
@ -1025,6 +1037,10 @@ class TTS:
|
||||
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
|
||||
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
@ -1046,6 +1062,18 @@ class TTS:
|
||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_collect_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
),
|
||||
@ -1073,6 +1101,10 @@ class TTS:
|
||||
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": 0.0,
|
||||
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||
"prompt_semantic_stage_limiter_wait_ms": 0.0,
|
||||
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
|
||||
"prompt_semantic_cpu_prepare_ms": 0.0,
|
||||
"prompt_semantic_forward_ms": 0.0,
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
@ -1116,6 +1148,18 @@ class TTS:
|
||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_collect_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
),
|
||||
@ -1193,6 +1237,18 @@ class TTS:
|
||||
"audio_stage_inflight_peak": float(audio_stage_inflight_peak),
|
||||
"prompt_semantic_ms": float(prompt_semantic_ms),
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_collect_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
|
||||
@ -51,6 +51,7 @@ class RefSemanticTask:
|
||||
raw_sr: int
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
batch_popped_at: float = 0.0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker:
|
||||
"max_batch_samples": self.max_batch_samples,
|
||||
}
|
||||
|
||||
def _collect_batch(self) -> List[RefSemanticTask]:
|
||||
def _collect_batch(self) -> tuple[List[RefSemanticTask], float]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
|
||||
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
|
||||
first_task = self.pending_tasks.popleft()
|
||||
first_task.batch_popped_at = time.perf_counter()
|
||||
batch: List[RefSemanticTask] = [first_task]
|
||||
batch_samples = self._estimate_task_samples(batch[0])
|
||||
deadline = time.perf_counter() + self.batch_window_s
|
||||
|
||||
@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker:
|
||||
next_samples = self._estimate_task_samples(next_task)
|
||||
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
|
||||
break
|
||||
batch.append(self.pending_tasks.popleft())
|
||||
popped_task = self.pending_tasks.popleft()
|
||||
popped_task.batch_popped_at = time.perf_counter()
|
||||
batch.append(popped_task)
|
||||
batch_samples += next_samples
|
||||
|
||||
self.active_batch_size = len(batch)
|
||||
@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker:
|
||||
self.active_batch_peak = self.active_batch_size
|
||||
if self.active_batch_samples > self.active_batch_samples_peak:
|
||||
self.active_batch_samples_peak = self.active_batch_samples
|
||||
return batch
|
||||
return batch, time.perf_counter()
|
||||
|
||||
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
|
||||
with self.condition:
|
||||
@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker:
|
||||
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
|
||||
def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None:
|
||||
batch_started = time.perf_counter()
|
||||
prepared_start = time.perf_counter()
|
||||
prepared_wavs = [
|
||||
@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker:
|
||||
try:
|
||||
code_len = int(code_lengths[batch_index].item())
|
||||
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
|
||||
worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0)
|
||||
batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0)
|
||||
stage_limiter_wait_ms = float(limiter_stats["wait_ms"])
|
||||
task.profile = {
|
||||
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_wait_ms": worker_queue_wait_ms
|
||||
+ batch_collect_wait_ms
|
||||
+ stage_limiter_wait_ms,
|
||||
"prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms,
|
||||
"prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms,
|
||||
"prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms,
|
||||
"prompt_semantic_batch_dispatch_delay_ms": max(
|
||||
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
||||
"prompt_semantic_forward_ms": float(forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker:
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
batch = self._collect_batch()
|
||||
batch, batch_collected_at = self._collect_batch()
|
||||
try:
|
||||
self._run_batch(batch)
|
||||
self._run_batch(batch, batch_collected_at)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
|
||||
@ -305,6 +305,18 @@ def build_request_state_from_parts(
|
||||
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_worker_queue_wait_ms": float(
|
||||
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_collect_wait_ms": float(
|
||||
bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||
bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user