mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-03 12:18:12 +08:00
Enhance audio processing in TTS framework with resampling and profiling improvements
Add resampling capabilities using torchaudio to prepare reference audio at 16kHz, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms to optimize resource usage. Update batch processing methods to include timing metrics for profiling, enhancing the ability to monitor and improve performance in the TTS system. This update improves the maintainability and efficiency of audio preparation workflows.
This commit is contained in:
parent
17cb2e5acf
commit
bc1f3f32de
@ -945,6 +945,13 @@ class TTS:
|
|||||||
codes = self.vits_model.extract_latent(hubert_feature)
|
codes = self.vits_model.extract_latent(hubert_feature)
|
||||||
return codes[0, 0].to(self.configs.device)
|
return codes[0, 0].to(self.configs.device)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor):
|
||||||
|
forward_start = time.perf_counter()
|
||||||
|
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
||||||
|
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||||
|
return prompt_semantic, forward_ms
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||||
cpu_prepare_start = time.perf_counter()
|
cpu_prepare_start = time.perf_counter()
|
||||||
@ -954,9 +961,7 @@ class TTS:
|
|||||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||||
)
|
)
|
||||||
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
||||||
forward_start = time.perf_counter()
|
prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
|
||||||
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
|
||||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
|
||||||
return prompt_semantic, cpu_prepare_ms, forward_ms
|
return prompt_semantic, cpu_prepare_ms, forward_ms
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -1011,10 +1016,17 @@ class TTS:
|
|||||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
||||||
load_ms = (time.perf_counter() - load_start) * 1000.0
|
load_ms = (time.perf_counter() - load_start) * 1000.0
|
||||||
if self.prepare_ref_semantic_batch_worker is None:
|
if self.prepare_ref_semantic_batch_worker is None:
|
||||||
|
prompt_semantic_cpu_prepare_start = time.perf_counter()
|
||||||
|
wav16k = prepare_prompt_semantic_wav16k(
|
||||||
|
raw_audio=raw_audio,
|
||||||
|
raw_sr=raw_sr,
|
||||||
|
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||||
|
)
|
||||||
|
prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0
|
||||||
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
|
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
|
||||||
prompt_semantic_start = time.perf_counter()
|
prompt_semantic_start = time.perf_counter()
|
||||||
prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = (
|
prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
|
||||||
self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
|
wav16k
|
||||||
)
|
)
|
||||||
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
|
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
|
||||||
ref_spec_start = time.perf_counter()
|
ref_spec_start = time.perf_counter()
|
||||||
@ -1025,6 +1037,10 @@ class TTS:
|
|||||||
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
|
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
|
||||||
prompt_semantic_profile = {
|
prompt_semantic_profile = {
|
||||||
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
|
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
|
||||||
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
|
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
|
||||||
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
|
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
|
||||||
"prompt_semantic_scatter_ms": 0.0,
|
"prompt_semantic_scatter_ms": 0.0,
|
||||||
@ -1046,6 +1062,18 @@ class TTS:
|
|||||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||||
"prompt_semantic_ms": prompt_semantic_ms,
|
"prompt_semantic_ms": prompt_semantic_ms,
|
||||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||||
|
),
|
||||||
"prompt_semantic_cpu_prepare_ms": float(
|
"prompt_semantic_cpu_prepare_ms": float(
|
||||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||||
),
|
),
|
||||||
@ -1073,6 +1101,10 @@ class TTS:
|
|||||||
|
|
||||||
prompt_semantic_profile = {
|
prompt_semantic_profile = {
|
||||||
"prompt_semantic_wait_ms": 0.0,
|
"prompt_semantic_wait_ms": 0.0,
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": 0.0,
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
|
||||||
"prompt_semantic_cpu_prepare_ms": 0.0,
|
"prompt_semantic_cpu_prepare_ms": 0.0,
|
||||||
"prompt_semantic_forward_ms": 0.0,
|
"prompt_semantic_forward_ms": 0.0,
|
||||||
"prompt_semantic_scatter_ms": 0.0,
|
"prompt_semantic_scatter_ms": 0.0,
|
||||||
@ -1116,6 +1148,18 @@ class TTS:
|
|||||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||||
"prompt_semantic_ms": prompt_semantic_ms,
|
"prompt_semantic_ms": prompt_semantic_ms,
|
||||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||||
|
),
|
||||||
"prompt_semantic_cpu_prepare_ms": float(
|
"prompt_semantic_cpu_prepare_ms": float(
|
||||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||||
),
|
),
|
||||||
@ -1193,6 +1237,18 @@ class TTS:
|
|||||||
"audio_stage_inflight_peak": float(audio_stage_inflight_peak),
|
"audio_stage_inflight_peak": float(audio_stage_inflight_peak),
|
||||||
"prompt_semantic_ms": float(prompt_semantic_ms),
|
"prompt_semantic_ms": float(prompt_semantic_ms),
|
||||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||||
|
prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||||
|
),
|
||||||
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||||
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||||
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||||
|
|||||||
@ -51,6 +51,7 @@ class RefSemanticTask:
|
|||||||
raw_sr: int
|
raw_sr: int
|
||||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||||
created_at: float = field(default_factory=time.perf_counter)
|
created_at: float = field(default_factory=time.perf_counter)
|
||||||
|
batch_popped_at: float = 0.0
|
||||||
done_event: threading.Event = field(default_factory=threading.Event)
|
done_event: threading.Event = field(default_factory=threading.Event)
|
||||||
done_loop: asyncio.AbstractEventLoop | None = None
|
done_loop: asyncio.AbstractEventLoop | None = None
|
||||||
done_future: asyncio.Future | None = None
|
done_future: asyncio.Future | None = None
|
||||||
@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker:
|
|||||||
"max_batch_samples": self.max_batch_samples,
|
"max_batch_samples": self.max_batch_samples,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _collect_batch(self) -> List[RefSemanticTask]:
|
def _collect_batch(self) -> tuple[List[RefSemanticTask], float]:
|
||||||
with self.condition:
|
with self.condition:
|
||||||
while not self.pending_tasks:
|
while not self.pending_tasks:
|
||||||
self.condition.wait()
|
self.condition.wait()
|
||||||
|
|
||||||
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
|
first_task = self.pending_tasks.popleft()
|
||||||
|
first_task.batch_popped_at = time.perf_counter()
|
||||||
|
batch: List[RefSemanticTask] = [first_task]
|
||||||
batch_samples = self._estimate_task_samples(batch[0])
|
batch_samples = self._estimate_task_samples(batch[0])
|
||||||
deadline = time.perf_counter() + self.batch_window_s
|
deadline = time.perf_counter() + self.batch_window_s
|
||||||
|
|
||||||
@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker:
|
|||||||
next_samples = self._estimate_task_samples(next_task)
|
next_samples = self._estimate_task_samples(next_task)
|
||||||
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
|
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
|
||||||
break
|
break
|
||||||
batch.append(self.pending_tasks.popleft())
|
popped_task = self.pending_tasks.popleft()
|
||||||
|
popped_task.batch_popped_at = time.perf_counter()
|
||||||
|
batch.append(popped_task)
|
||||||
batch_samples += next_samples
|
batch_samples += next_samples
|
||||||
|
|
||||||
self.active_batch_size = len(batch)
|
self.active_batch_size = len(batch)
|
||||||
@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker:
|
|||||||
self.active_batch_peak = self.active_batch_size
|
self.active_batch_peak = self.active_batch_size
|
||||||
if self.active_batch_samples > self.active_batch_samples_peak:
|
if self.active_batch_samples > self.active_batch_samples_peak:
|
||||||
self.active_batch_samples_peak = self.active_batch_samples
|
self.active_batch_samples_peak = self.active_batch_samples
|
||||||
return batch
|
return batch, time.perf_counter()
|
||||||
|
|
||||||
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
|
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
|
||||||
with self.condition:
|
with self.condition:
|
||||||
@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker:
|
|||||||
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
|
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
|
def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None:
|
||||||
batch_started = time.perf_counter()
|
batch_started = time.perf_counter()
|
||||||
prepared_start = time.perf_counter()
|
prepared_start = time.perf_counter()
|
||||||
prepared_wavs = [
|
prepared_wavs = [
|
||||||
@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker:
|
|||||||
try:
|
try:
|
||||||
code_len = int(code_lengths[batch_index].item())
|
code_len = int(code_lengths[batch_index].item())
|
||||||
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
|
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
|
||||||
|
worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0)
|
||||||
|
batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0)
|
||||||
|
stage_limiter_wait_ms = float(limiter_stats["wait_ms"])
|
||||||
task.profile = {
|
task.profile = {
|
||||||
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
|
"prompt_semantic_wait_ms": worker_queue_wait_ms
|
||||||
|
+ batch_collect_wait_ms
|
||||||
|
+ stage_limiter_wait_ms,
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms,
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms,
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms,
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": max(
|
||||||
|
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
|
||||||
|
),
|
||||||
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
||||||
"prompt_semantic_forward_ms": float(forward_ms),
|
"prompt_semantic_forward_ms": float(forward_ms),
|
||||||
"prompt_semantic_scatter_ms": 0.0,
|
"prompt_semantic_scatter_ms": 0.0,
|
||||||
@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker:
|
|||||||
|
|
||||||
def _run_loop(self) -> None:
|
def _run_loop(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
batch = self._collect_batch()
|
batch, batch_collected_at = self._collect_batch()
|
||||||
try:
|
try:
|
||||||
self._run_batch(batch)
|
self._run_batch(batch, batch_collected_at)
|
||||||
except Exception as exc: # noqa: PERF203
|
except Exception as exc: # noqa: PERF203
|
||||||
for task in batch:
|
for task in batch:
|
||||||
task.error = exc
|
task.error = exc
|
||||||
|
|||||||
@ -305,6 +305,18 @@ def build_request_state_from_parts(
|
|||||||
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
|
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
|
||||||
"prompt_semantic_ms": prompt_semantic_ms,
|
"prompt_semantic_ms": prompt_semantic_ms,
|
||||||
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
|
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||||
|
"prompt_semantic_worker_queue_wait_ms": float(
|
||||||
|
bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_collect_wait_ms": float(
|
||||||
|
bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_stage_limiter_wait_ms": float(
|
||||||
|
bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
|
||||||
|
),
|
||||||
|
"prompt_semantic_batch_dispatch_delay_ms": float(
|
||||||
|
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||||
|
),
|
||||||
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||||
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
|
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||||
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user