mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-13 13:28:15 +08:00
Enhance TTS processing with new reference specification handling and profiling metrics
Refactor the PrepareCoordinator and related components to improve the handling of reference specifications in the TTS system. Introduce new methods for building and extracting reference prompts and specifications, along with detailed profiling metrics for performance monitoring. Update the PrepareRefSemanticBatchWorker to include additional timing metrics and caching mechanisms for resampling. These changes enhance the efficiency and maintainability of the TTS framework, particularly in managing audio processing and reference data.
This commit is contained in:
parent
c94de2f2cb
commit
8a444c10b7
@ -996,21 +996,39 @@ class TTS:
|
||||
return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
|
||||
|
||||
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr)
|
||||
return spec, audio, raw_audio, raw_sr
|
||||
|
||||
def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
profile = {
|
||||
"ref_spec_to_device_ms": 0.0,
|
||||
"ref_spec_main_resample_ms": 0.0,
|
||||
"ref_spec_norm_ms": 0.0,
|
||||
"ref_spec_spectrogram_ms": 0.0,
|
||||
"ref_spec_post_resample_ms": 0.0,
|
||||
}
|
||||
to_device_start = time.perf_counter()
|
||||
raw_audio_device = raw_audio.to(self.configs.device).float()
|
||||
profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0
|
||||
|
||||
if raw_sr != self.configs.sampling_rate:
|
||||
resample_start = time.perf_counter()
|
||||
audio = raw_audio_device
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
||||
profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0
|
||||
else:
|
||||
audio = raw_audio_device
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
norm_start = time.perf_counter()
|
||||
maxx = audio.abs().max()
|
||||
if maxx > 1:
|
||||
audio /= min(2, maxx)
|
||||
profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0
|
||||
spec_start = time.perf_counter()
|
||||
spec = spectrogram_torch(
|
||||
audio,
|
||||
self.configs.filter_length,
|
||||
@ -1019,15 +1037,18 @@ class TTS:
|
||||
self.configs.win_length,
|
||||
center=False,
|
||||
)
|
||||
profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0
|
||||
if self.configs.is_half:
|
||||
spec = spec.half()
|
||||
if self.is_v2pro == True:
|
||||
post_resample_start = time.perf_counter()
|
||||
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
|
||||
profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0
|
||||
if self.configs.is_half:
|
||||
audio = audio.half()
|
||||
else:
|
||||
audio = None
|
||||
return spec, audio, raw_audio, raw_sr
|
||||
return spec, audio, raw_audio, raw_sr, profile
|
||||
|
||||
def extract_ref_spec(self, ref_audio_path: str):
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
||||
|
||||
@ -228,8 +228,74 @@ class PrepareCoordinator:
|
||||
def _load_ref_audio_raw(self, ref_audio_path: str):
|
||||
return self.tts._load_ref_audio_raw(ref_audio_path)
|
||||
|
||||
def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int):
|
||||
load_profile = {"audio_load_ms": 0.0}
|
||||
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
|
||||
prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
|
||||
return {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
**load_profile,
|
||||
"audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
"audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
|
||||
"prompt_semantic_ms": float(
|
||||
worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
+ worker_profile.get("prompt_semantic_forward_ms", 0.0)
|
||||
+ worker_profile.get("prompt_semantic_scatter_ms", 0.0)
|
||||
),
|
||||
**{key: float(value) for key, value in worker_profile.items()},
|
||||
"ref_spec_wait_ms": 0.0,
|
||||
"ref_spec_ms": 0.0,
|
||||
"bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0))
|
||||
+ float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
|
||||
+ float(worker_profile.get("prompt_semantic_forward_ms", 0.0))
|
||||
+ float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
},
|
||||
}
|
||||
wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
|
||||
with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats:
|
||||
prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
|
||||
return {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
"audio_load_ms": 0.0,
|
||||
"audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
|
||||
"audio_stage_slots": float(stage_stats.get("slots", 0.0)),
|
||||
"audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)),
|
||||
"prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)),
|
||||
"prompt_semantic_worker_queue_wait_ms": 0.0,
|
||||
"prompt_semantic_batch_collect_wait_ms": 0.0,
|
||||
"prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)),
|
||||
"prompt_semantic_batch_dispatch_delay_ms": 0.0,
|
||||
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
||||
"prompt_semantic_pack_ms": 0.0,
|
||||
"prompt_semantic_h2d_ms": 0.0,
|
||||
"prompt_semantic_ssl_forward_ms": 0.0,
|
||||
"prompt_semantic_hidden_length_ms": 0.0,
|
||||
"prompt_semantic_extract_latent_ms": 0.0,
|
||||
"prompt_semantic_forward_ms": float(forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)),
|
||||
"prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)),
|
||||
"prompt_semantic_batch_size": 1.0,
|
||||
"prompt_semantic_batch_samples": 0.0,
|
||||
"ref_spec_wait_ms": 0.0,
|
||||
"ref_spec_ms": 0.0,
|
||||
"bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)),
|
||||
},
|
||||
}
|
||||
|
||||
def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int):
|
||||
return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
|
||||
spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr)
|
||||
return (spec, audio), profile
|
||||
|
||||
@staticmethod
|
||||
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
|
||||
@ -523,7 +589,7 @@ class PrepareCoordinator:
|
||||
finally:
|
||||
self.text_feature_gate.release()
|
||||
|
||||
async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
|
||||
async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult:
|
||||
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
|
||||
submit_at = time.perf_counter()
|
||||
started_at = float(submit_at)
|
||||
@ -538,19 +604,7 @@ class PrepareCoordinator:
|
||||
prompt_semantic_task = asyncio.create_task(
|
||||
self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
|
||||
)
|
||||
await self.ref_spec_gate.acquire()
|
||||
try:
|
||||
ref_spec_task = asyncio.create_task(
|
||||
self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
|
||||
)
|
||||
(prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather(
|
||||
prompt_semantic_task,
|
||||
ref_spec_task,
|
||||
)
|
||||
finally:
|
||||
self.ref_spec_gate.release()
|
||||
|
||||
refer_spec = ref_spec_profiled.result
|
||||
prompt_semantic, prompt_semantic_profile = await prompt_semantic_task
|
||||
limiter_snapshot = (
|
||||
self.tts.prepare_ref_audio_stage_limiter.snapshot()
|
||||
if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None
|
||||
@ -561,21 +615,15 @@ class PrepareCoordinator:
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
|
||||
)
|
||||
audio_stage_wait_ms = (
|
||||
float(load_profiled.queue_ms)
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0))
|
||||
+ float(ref_spec_profiled.queue_ms)
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
result = {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"refer_spec": refer_spec,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
"audio_load_queue_ms": float(load_profiled.queue_ms),
|
||||
"audio_load_ms": float(load_profiled.run_ms),
|
||||
"audio_stage_wait_ms": float(audio_stage_wait_ms),
|
||||
"audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"audio_stage_slots": float(
|
||||
max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
@ -593,6 +641,17 @@ class PrepareCoordinator:
|
||||
"prompt_semantic_cpu_prepare_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)),
|
||||
"prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)),
|
||||
"prompt_semantic_ssl_forward_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_hidden_length_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_extract_latent_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
@ -603,14 +662,10 @@ class PrepareCoordinator:
|
||||
"prompt_semantic_batch_samples": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
|
||||
),
|
||||
"ref_spec_wait_ms": float(ref_spec_profiled.queue_ms),
|
||||
"ref_spec_ms": float(ref_spec_profiled.run_ms),
|
||||
"bundle_total_ms": float(
|
||||
load_profiled.queue_ms
|
||||
+ load_profiled.run_ms
|
||||
+ prompt_semantic_ms
|
||||
+ ref_spec_profiled.queue_ms
|
||||
+ ref_spec_profiled.run_ms
|
||||
),
|
||||
},
|
||||
}
|
||||
@ -623,21 +678,26 @@ class PrepareCoordinator:
|
||||
|
||||
await self.ref_audio_gate.acquire()
|
||||
try:
|
||||
if hasattr(self.tts, "extract_ref_audio_bundle_async"):
|
||||
submit_at = time.perf_counter()
|
||||
started_at = time.perf_counter()
|
||||
result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path)
|
||||
finished_at = time.perf_counter()
|
||||
return ProfiledResult(
|
||||
result=result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(started_at),
|
||||
finished_at=float(finished_at),
|
||||
)
|
||||
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
|
||||
load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path)
|
||||
raw_audio, raw_sr = load_profiled.result
|
||||
submit_at = time.perf_counter()
|
||||
started_at = time.perf_counter()
|
||||
result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr)
|
||||
result.setdefault("profile", {})
|
||||
result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms)
|
||||
result["profile"]["audio_load_ms"] = float(load_profiled.run_ms)
|
||||
finished_at = time.perf_counter()
|
||||
return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at))
|
||||
finally:
|
||||
self.ref_audio_gate.release()
|
||||
|
||||
async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult:
|
||||
await self.ref_spec_gate.acquire()
|
||||
try:
|
||||
return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
|
||||
finally:
|
||||
self.ref_spec_gate.release()
|
||||
|
||||
def _release_split_stage_slot(self) -> None:
|
||||
self._mark_leave()
|
||||
self._inflight_gate.release()
|
||||
@ -682,101 +742,318 @@ class PrepareCoordinator:
|
||||
cpu_stage: PreparedCpuStage,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
try:
|
||||
g2pw_pair_start = time.perf_counter()
|
||||
g2pw_pair_task = asyncio.create_task(
|
||||
self._run_g2pw_pair_stage(
|
||||
cpu_stage.prompt_cpu_profiled.result,
|
||||
cpu_stage.target_cpu_profiled.result,
|
||||
)
|
||||
)
|
||||
ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
|
||||
(prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather(
|
||||
g2pw_pair_task,
|
||||
ref_audio_task,
|
||||
)
|
||||
g2pw_pair_end = time.perf_counter()
|
||||
text_pair_start = time.perf_counter()
|
||||
text_feature_pair_task = asyncio.create_task(
|
||||
self._run_text_feature_pair_stage(
|
||||
prompt_g2pw_profiled.result,
|
||||
target_g2pw_profiled.result,
|
||||
cpu_stage.prompt_cpu_profiled.run_ms,
|
||||
cpu_stage.target_cpu_profiled.run_ms,
|
||||
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
|
||||
target_base_profile=dict(target_g2pw_profiled.profile or {}),
|
||||
)
|
||||
)
|
||||
prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task
|
||||
text_pair_end = time.perf_counter()
|
||||
state = build_request_state_from_parts(
|
||||
tts=self.tts,
|
||||
spec=cpu_stage.spec,
|
||||
prompt_text=cpu_stage.prompt_text,
|
||||
text=cpu_stage.text,
|
||||
prompt_result=prompt_feature_profiled.result,
|
||||
target_result=target_feature_profiled.result,
|
||||
ref_audio_bundle=ref_audio_profiled.result,
|
||||
prepare_start=cpu_stage.prepare_start,
|
||||
prepare_sync_start=cpu_stage.prepare_start,
|
||||
profile_overrides={
|
||||
"executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
|
||||
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
|
||||
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
|
||||
"text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
|
||||
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0),
|
||||
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
|
||||
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
|
||||
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
|
||||
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
|
||||
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
|
||||
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
|
||||
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
|
||||
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
|
||||
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
|
||||
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
|
||||
"prompt_text_parallel_future_wait_ms": 0.0,
|
||||
"prompt_text_parallel_future_executor_queue_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_ms": 0.0,
|
||||
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
|
||||
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
|
||||
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
|
||||
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
|
||||
"prompt_text_cpu_admission_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_cpu_backpressure_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_cpu_capacity_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
|
||||
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
|
||||
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
|
||||
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
|
||||
"text_cpu_admission_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
|
||||
),
|
||||
"text_cpu_backpressure_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
|
||||
),
|
||||
"text_cpu_capacity_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
|
||||
),
|
||||
"text_feature_queue_ms": target_feature_profiled.queue_ms,
|
||||
"text_feature_run_ms": target_feature_profiled.run_ms,
|
||||
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
|
||||
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
|
||||
"worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
|
||||
"worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
|
||||
phase_one = await self._prepare_gpu_phase_one(cpu_stage)
|
||||
phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one)
|
||||
return self._build_gpu_prepare_result(
|
||||
cpu_stage,
|
||||
phase_one,
|
||||
phase_two,
|
||||
extra_profile={
|
||||
"engine_prepare_audio_phase_mode": 0.0,
|
||||
"engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]),
|
||||
"engine_prepare_audio_phase_batch_size": 1.0,
|
||||
"engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]),
|
||||
"engine_prepare_text_phase_batch_size": 1.0,
|
||||
},
|
||||
)
|
||||
prepare_exec_finished_at = time.perf_counter()
|
||||
state.prepare_profile["executor_run_wall_ms"] = max(
|
||||
0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0
|
||||
finally:
|
||||
self._release_split_stage_slot()
|
||||
|
||||
async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]:
|
||||
phase_start = time.perf_counter()
|
||||
g2pw_pair_task = asyncio.create_task(
|
||||
self._run_g2pw_pair_stage(
|
||||
cpu_stage.prompt_cpu_profiled.result,
|
||||
cpu_stage.target_cpu_profiled.result,
|
||||
)
|
||||
return state, cpu_stage.prepare_start, prepare_exec_finished_at
|
||||
)
|
||||
ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path)))
|
||||
prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task
|
||||
g2pw_pair_end = time.perf_counter()
|
||||
ref_audio_profiled = await ref_audio_task
|
||||
phase_end = time.perf_counter()
|
||||
return {
|
||||
"prompt_g2pw_profiled": prompt_g2pw_profiled,
|
||||
"target_g2pw_profiled": target_g2pw_profiled,
|
||||
"ref_audio_profiled": ref_audio_profiled,
|
||||
"ref_spec_result": None,
|
||||
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0),
|
||||
"phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0),
|
||||
}
|
||||
|
||||
async def _prepare_gpu_phase_two(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
phase_start = time.perf_counter()
|
||||
prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"]
|
||||
target_g2pw_profiled = phase_one["target_g2pw_profiled"]
|
||||
prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage(
|
||||
prompt_g2pw_profiled.result,
|
||||
target_g2pw_profiled.result,
|
||||
cpu_stage.prompt_cpu_profiled.run_ms,
|
||||
cpu_stage.target_cpu_profiled.run_ms,
|
||||
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
|
||||
target_base_profile=dict(target_g2pw_profiled.profile or {}),
|
||||
)
|
||||
phase_end = time.perf_counter()
|
||||
return {
|
||||
"prompt_feature_profiled": prompt_feature_profiled,
|
||||
"target_feature_profiled": target_feature_profiled,
|
||||
"phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0),
|
||||
}
|
||||
|
||||
def _build_gpu_prepare_result(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
phase_two: Dict[str, Any],
|
||||
extra_profile: Dict[str, float] | None = None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"]
|
||||
target_g2pw_profiled = phase_one["target_g2pw_profiled"]
|
||||
ref_audio_profiled = phase_one["ref_audio_profiled"]
|
||||
ref_spec_result = phase_one.get("ref_spec_result")
|
||||
prompt_feature_profiled = phase_two["prompt_feature_profiled"]
|
||||
target_feature_profiled = phase_two["target_feature_profiled"]
|
||||
profile_overrides = {
|
||||
"executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0),
|
||||
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
|
||||
"prepare_submit_ts": float(cpu_stage.prepare_submit_at),
|
||||
"prepare_cpu_start_ts": float(cpu_stage.prepare_start),
|
||||
"prepare_cpu_done_ts": float(
|
||||
max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at)
|
||||
),
|
||||
"prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at),
|
||||
"prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at),
|
||||
"text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at),
|
||||
"text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at),
|
||||
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
|
||||
"text_feature_pair_ms": float(phase_two["phase_wall_ms"]),
|
||||
"g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]),
|
||||
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
|
||||
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
|
||||
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
|
||||
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
|
||||
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
|
||||
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
|
||||
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
|
||||
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
|
||||
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
|
||||
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
|
||||
"prompt_text_parallel_future_wait_ms": 0.0,
|
||||
"prompt_text_parallel_future_executor_queue_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_ms": 0.0,
|
||||
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
|
||||
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
|
||||
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
|
||||
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
|
||||
"prompt_text_cpu_admission_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_cpu_backpressure_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_cpu_capacity_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
|
||||
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
|
||||
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
|
||||
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
|
||||
"text_cpu_admission_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
|
||||
),
|
||||
"text_cpu_backpressure_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
|
||||
),
|
||||
"text_cpu_capacity_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
|
||||
),
|
||||
"text_feature_queue_ms": target_feature_profiled.queue_ms,
|
||||
"text_feature_run_ms": target_feature_profiled.run_ms,
|
||||
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
|
||||
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
|
||||
"worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight),
|
||||
"worker_prepare_peak_inflight": float(cpu_stage.peak_inflight),
|
||||
}
|
||||
if extra_profile:
|
||||
profile_overrides.update({key: float(value) for key, value in extra_profile.items()})
|
||||
ref_audio_bundle = dict(ref_audio_profiled.result)
|
||||
ref_audio_profile = dict(ref_audio_bundle.get("profile", {}))
|
||||
if ref_spec_result is not None:
|
||||
refer_spec, ref_spec_profiled = ref_spec_result
|
||||
ref_audio_bundle["refer_spec"] = refer_spec
|
||||
ref_audio_profile.update(
|
||||
{
|
||||
"ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)),
|
||||
"ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)),
|
||||
"ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)),
|
||||
"ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)),
|
||||
"ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)),
|
||||
"ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)),
|
||||
"ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)),
|
||||
}
|
||||
)
|
||||
else:
|
||||
ref_audio_bundle["refer_spec"] = None
|
||||
ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0)
|
||||
ref_audio_profile.setdefault("ref_spec_ms", 0.0)
|
||||
ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0)
|
||||
ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0)
|
||||
ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0)
|
||||
ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0)
|
||||
ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0)
|
||||
ref_audio_bundle["profile"] = ref_audio_profile
|
||||
state = build_request_state_from_parts(
|
||||
tts=self.tts,
|
||||
spec=cpu_stage.spec,
|
||||
prompt_text=cpu_stage.prompt_text,
|
||||
text=cpu_stage.text,
|
||||
prompt_result=prompt_feature_profiled.result,
|
||||
target_result=target_feature_profiled.result,
|
||||
ref_audio_bundle=ref_audio_bundle,
|
||||
prepare_start=cpu_stage.prepare_start,
|
||||
prepare_sync_start=cpu_stage.prepare_start,
|
||||
profile_overrides=profile_overrides,
|
||||
)
|
||||
prepare_exec_finished_at = time.perf_counter()
|
||||
state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0)
|
||||
return state, cpu_stage.prepare_start, prepare_exec_finished_at
|
||||
|
||||
async def prepare_ref_spec_stages_async(
|
||||
self,
|
||||
phase_ones: list[Dict[str, Any]],
|
||||
) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
|
||||
async def _one(phase_one: Dict[str, Any]):
|
||||
ref_audio_profiled = phase_one["ref_audio_profiled"]
|
||||
raw_audio = ref_audio_profiled.result["raw_audio"]
|
||||
raw_sr = int(ref_audio_profiled.result["raw_sr"])
|
||||
profiled = await self._run_ref_spec_stage(raw_audio, raw_sr)
|
||||
refer_spec, profile = profiled.result
|
||||
merged_profile = dict(profile)
|
||||
merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms)
|
||||
merged_profile["ref_spec_ms"] = float(profiled.run_ms)
|
||||
return refer_spec, merged_profile
|
||||
|
||||
if not phase_ones:
|
||||
return []
|
||||
return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True))
|
||||
|
||||
def apply_ref_spec_result_to_state(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
|
||||
) -> None:
|
||||
refer_spec, profile = ref_spec_result
|
||||
state.refer_spec = refer_spec
|
||||
state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0))
|
||||
state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0))
|
||||
state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0))
|
||||
state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0))
|
||||
state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0))
|
||||
state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0))
|
||||
state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0))
|
||||
|
||||
async def prepare_gpu_stages_profiled_async(
|
||||
self,
|
||||
cpu_stages: list[PreparedCpuStage],
|
||||
) -> list[tuple[T2SRequestState, float, float] | Exception]:
|
||||
if not cpu_stages:
|
||||
return []
|
||||
if len(cpu_stages) == 1:
|
||||
single_stage = cpu_stages[0]
|
||||
try:
|
||||
return [await self.prepare_gpu_stage_profiled_async(single_stage)]
|
||||
except Exception as exc: # noqa: PERF203
|
||||
return [exc]
|
||||
|
||||
phase_one_started_at = time.perf_counter()
|
||||
phase_one_results = await asyncio.gather(
|
||||
*[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages],
|
||||
return_exceptions=True,
|
||||
)
|
||||
phase_one_finished_at = time.perf_counter()
|
||||
phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0)
|
||||
|
||||
outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages)
|
||||
pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = []
|
||||
for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)):
|
||||
if isinstance(phase_one, Exception):
|
||||
outputs[index] = phase_one
|
||||
self._release_split_stage_slot()
|
||||
continue
|
||||
pending_phase_two.append((index, cpu_stage, phase_one))
|
||||
|
||||
phase_two_started_at = time.perf_counter()
|
||||
phase_two_results = await asyncio.gather(
|
||||
*[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two],
|
||||
return_exceptions=True,
|
||||
)
|
||||
phase_two_finished_at = time.perf_counter()
|
||||
phase_two_wall_ms = max(0.0, (phase_two_finished_at - phase_two_started_at) * 1000.0)
|
||||
|
||||
for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results):
|
||||
try:
|
||||
if isinstance(phase_two, Exception):
|
||||
outputs[index] = phase_two
|
||||
continue
|
||||
outputs[index] = self._build_gpu_prepare_result(
|
||||
cpu_stage,
|
||||
phase_one,
|
||||
phase_two,
|
||||
extra_profile={
|
||||
"engine_prepare_audio_phase_mode": 1.0,
|
||||
"engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms),
|
||||
"engine_prepare_audio_phase_batch_size": float(len(cpu_stages)),
|
||||
"engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms),
|
||||
"engine_prepare_text_phase_batch_size": float(len(pending_phase_two)),
|
||||
},
|
||||
)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
outputs[index] = exc
|
||||
finally:
|
||||
self._release_split_stage_slot()
|
||||
|
||||
return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs]
|
||||
|
||||
async def prepare_gpu_audio_phases_async(
|
||||
self,
|
||||
cpu_stages: list[PreparedCpuStage],
|
||||
) -> list[Dict[str, Any] | Exception]:
|
||||
if not cpu_stages:
|
||||
return []
|
||||
return list(
|
||||
await asyncio.gather(
|
||||
*[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages],
|
||||
return_exceptions=True,
|
||||
)
|
||||
)
|
||||
|
||||
async def prepare_gpu_text_phases_async(
|
||||
self,
|
||||
items: list[tuple[PreparedCpuStage, Dict[str, Any]]],
|
||||
) -> list[Dict[str, Any] | Exception]:
|
||||
if not items:
|
||||
return []
|
||||
return list(
|
||||
await asyncio.gather(
|
||||
*[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items],
|
||||
return_exceptions=True,
|
||||
)
|
||||
)
|
||||
|
||||
def build_gpu_prepare_result_from_phases(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
phase_two: Dict[str, Any],
|
||||
extra_profile: Dict[str, float] | None = None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
try:
|
||||
return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile)
|
||||
finally:
|
||||
self._release_split_stage_slot()
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ REF_AUDIO_MIN_SAMPLES_16K = 48000
|
||||
REF_AUDIO_MAX_SAMPLES_16K = 160000
|
||||
_RESAMPLE_CACHE_LOCK = threading.Lock()
|
||||
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
|
||||
_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {}
|
||||
|
||||
|
||||
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
|
||||
@ -28,6 +29,16 @@ def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.tran
|
||||
return transform
|
||||
|
||||
|
||||
def _get_resample_stream(device: str) -> torch.cuda.Stream:
|
||||
device_key = str(device)
|
||||
with _RESAMPLE_CACHE_LOCK:
|
||||
stream = _RESAMPLE_STREAM_CACHE.get(device_key)
|
||||
if stream is None:
|
||||
stream = torch.cuda.Stream(device=device_key)
|
||||
_RESAMPLE_STREAM_CACHE[device_key] = stream
|
||||
return stream
|
||||
|
||||
|
||||
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
|
||||
resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
|
||||
if resample_device not in {"cpu", "cuda"}:
|
||||
@ -37,10 +48,20 @@ def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wa
|
||||
wav_mono = raw_audio
|
||||
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
|
||||
wav_mono = wav_mono.mean(0, keepdim=True)
|
||||
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
|
||||
if raw_sr != 16000:
|
||||
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
|
||||
wav16k = wav16k.squeeze(0).contiguous()
|
||||
if resample_device == "cuda":
|
||||
stream = _get_resample_stream(resample_device)
|
||||
with torch.cuda.stream(stream):
|
||||
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
|
||||
if raw_sr != 16000:
|
||||
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
|
||||
wav16k = wav16k.squeeze(0).contiguous()
|
||||
stream.synchronize()
|
||||
wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous()
|
||||
else:
|
||||
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
|
||||
if raw_sr != 16000:
|
||||
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
|
||||
wav16k = wav16k.squeeze(0).contiguous()
|
||||
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
|
||||
raise OSError("参考音频在3~10秒范围外,请更换!")
|
||||
if zero_wav_samples > 0:
|
||||
@ -256,37 +277,56 @@ class PrepareRefSemanticBatchWorker:
|
||||
batch_samples = int(wav_lengths.sum().item())
|
||||
max_wav_len = int(wav_lengths.max().item())
|
||||
|
||||
pack_start = time.perf_counter()
|
||||
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
|
||||
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
|
||||
for batch_index, wav in enumerate(prepared_wavs):
|
||||
wav_len = int(wav.shape[0])
|
||||
input_values_cpu[batch_index, :wav_len] = wav
|
||||
attention_mask_cpu[batch_index, :wav_len] = 1
|
||||
pack_ms = (time.perf_counter() - pack_start) * 1000.0
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
|
||||
h2d_ms = 0.0
|
||||
ssl_forward_ms = 0.0
|
||||
hidden_length_ms = 0.0
|
||||
extract_latent_ms = 0.0
|
||||
if self.stage_limiter is None:
|
||||
h2d_start = time.perf_counter()
|
||||
input_values = input_values_cpu.to(self.device)
|
||||
attention_mask = attention_mask_cpu.to(self.device)
|
||||
if self.is_half:
|
||||
input_values = input_values.half()
|
||||
forward_start = time.perf_counter()
|
||||
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
|
||||
ssl_start = time.perf_counter()
|
||||
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
|
||||
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
|
||||
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
|
||||
hidden_length_start = time.perf_counter()
|
||||
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
|
||||
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
|
||||
latent_start = time.perf_counter()
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
|
||||
else:
|
||||
with self.stage_limiter.enter() as limiter_stats:
|
||||
h2d_start = time.perf_counter()
|
||||
input_values = input_values_cpu.to(self.device)
|
||||
attention_mask = attention_mask_cpu.to(self.device)
|
||||
if self.is_half:
|
||||
input_values = input_values.half()
|
||||
forward_start = time.perf_counter()
|
||||
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
|
||||
ssl_start = time.perf_counter()
|
||||
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
|
||||
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
|
||||
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
|
||||
hidden_length_start = time.perf_counter()
|
||||
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
|
||||
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
|
||||
latent_start = time.perf_counter()
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
|
||||
forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms)
|
||||
|
||||
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
|
||||
scatter_start = time.perf_counter()
|
||||
@ -308,6 +348,11 @@ class PrepareRefSemanticBatchWorker:
|
||||
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
||||
"prompt_semantic_pack_ms": float(pack_ms),
|
||||
"prompt_semantic_h2d_ms": float(h2d_ms),
|
||||
"prompt_semantic_ssl_forward_ms": float(ssl_forward_ms),
|
||||
"prompt_semantic_hidden_length_ms": float(hidden_length_ms),
|
||||
"prompt_semantic_extract_latent_ms": float(extract_latent_ms),
|
||||
"prompt_semantic_forward_ms": float(forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_calls": 1.0,
|
||||
|
||||
@ -55,7 +55,7 @@ class T2SRequestState:
|
||||
all_phones: torch.LongTensor
|
||||
all_bert_features: torch.Tensor
|
||||
prompt_semantic: torch.LongTensor
|
||||
refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]]
|
||||
refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]
|
||||
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
|
||||
raw_audio: torch.Tensor
|
||||
raw_sr: int
|
||||
@ -188,7 +188,11 @@ def build_request_state_from_parts(
|
||||
ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0))
|
||||
bundle_profile = ref_audio_bundle.get("profile", {})
|
||||
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
|
||||
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
|
||||
refer_spec_value = ref_audio_bundle.get("refer_spec")
|
||||
if refer_spec_value in [None, ()]:
|
||||
spec_audio, audio_16k = None, None
|
||||
else:
|
||||
spec_audio, audio_16k = refer_spec_value
|
||||
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
|
||||
for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []):
|
||||
if aux_ref_audio_path in [None, ""]:
|
||||
@ -323,6 +327,11 @@ def build_request_state_from_parts(
|
||||
bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)),
|
||||
"prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)),
|
||||
"prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)),
|
||||
"prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)),
|
||||
"prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)),
|
||||
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
"prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
@ -331,6 +340,11 @@ def build_request_state_from_parts(
|
||||
"prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)),
|
||||
"ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)),
|
||||
"ref_spec_ms": ref_spec_ms,
|
||||
"ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)),
|
||||
"ref_spec_main_resample_ms": float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)),
|
||||
"ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)),
|
||||
"ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)),
|
||||
"ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)),
|
||||
"ref_audio_bundle_ms": ref_audio_bundle_ms,
|
||||
"tensorize_ms": tensorize_ms,
|
||||
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
|
||||
@ -352,7 +366,7 @@ def build_request_state_from_parts(
|
||||
all_phones=all_phones,
|
||||
all_bert_features=all_bert_features,
|
||||
prompt_semantic=prompt_semantic,
|
||||
refer_spec=(spec_audio, audio_16k),
|
||||
refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)),
|
||||
aux_refer_specs=aux_refer_specs,
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
|
||||
@ -21,6 +21,14 @@ class EngineRegistryBridgeFacade:
|
||||
def engine_prepare_queue_owner(self):
|
||||
return self.owner.engine_prepare_queue_owner
|
||||
|
||||
@property
|
||||
def engine_prepare_text_queue_owner(self):
|
||||
return self.owner.engine_prepare_text_queue_owner
|
||||
|
||||
@property
|
||||
def engine_prepare_ref_spec_queue_owner(self):
|
||||
return self.owner.engine_prepare_ref_spec_queue_owner
|
||||
|
||||
@property
|
||||
def engine_finalize_queue_owner(self):
|
||||
return self.owner.engine_finalize_queue_owner
|
||||
@ -82,7 +90,33 @@ class EngineRegistryBridgeFacade:
|
||||
return self.request_registry.snapshot()
|
||||
|
||||
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
|
||||
return self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
|
||||
audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
|
||||
text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16)
|
||||
ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16)
|
||||
return {
|
||||
"waiting_count": int(audio_snapshot.get("waiting_count", 0))
|
||||
+ int(text_snapshot.get("waiting_count", 0))
|
||||
+ int(ref_spec_snapshot.get("waiting_count", 0)),
|
||||
"audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)),
|
||||
"text_waiting_count": int(text_snapshot.get("waiting_count", 0)),
|
||||
"ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)),
|
||||
"audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])),
|
||||
"text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])),
|
||||
"ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])),
|
||||
"peak_waiting": int(
|
||||
max(
|
||||
int(audio_snapshot.get("peak_waiting", 0)),
|
||||
int(text_snapshot.get("peak_waiting", 0)),
|
||||
int(ref_spec_snapshot.get("peak_waiting", 0)),
|
||||
)
|
||||
),
|
||||
"total_submitted": int(audio_snapshot.get("total_submitted", 0)),
|
||||
"total_completed": int(audio_snapshot.get("total_completed", 0)),
|
||||
"text_total_submitted": int(text_snapshot.get("total_submitted", 0)),
|
||||
"text_total_completed": int(text_snapshot.get("total_completed", 0)),
|
||||
"ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)),
|
||||
"ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)),
|
||||
}
|
||||
|
||||
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
|
||||
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
|
||||
@ -107,6 +141,8 @@ class EngineRegistryBridgeFacade:
|
||||
|
||||
def _is_engine_drained(self) -> bool:
|
||||
prepare_empty = self.engine_prepare_queue_owner.is_drained()
|
||||
prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained()
|
||||
prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained()
|
||||
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
|
||||
finalize_empty = self.engine_finalize_queue_owner.is_drained()
|
||||
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
|
||||
@ -114,6 +150,8 @@ class EngineRegistryBridgeFacade:
|
||||
worker_state = self.scheduler_worker.snapshot()
|
||||
return bool(
|
||||
prepare_empty
|
||||
and prepare_text_empty
|
||||
and prepare_ref_spec_empty
|
||||
and dispatch_empty
|
||||
and finalize_empty
|
||||
and decode_pending_empty
|
||||
|
||||
@ -110,6 +110,8 @@ class EngineCompositionBuilder:
|
||||
get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s,
|
||||
)
|
||||
owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")
|
||||
|
||||
@ -119,6 +121,8 @@ class EngineCompositionBuilder:
|
||||
tts=owner.tts,
|
||||
scheduler_worker=owner.scheduler_worker,
|
||||
prepare_queue_owner=owner.engine_prepare_queue_owner,
|
||||
prepare_text_queue_owner=owner.engine_prepare_text_queue_owner,
|
||||
prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner,
|
||||
finalize_queue_owner=owner.engine_finalize_queue_owner,
|
||||
dispatch_queue_owner=owner.engine_dispatch_queue_owner,
|
||||
decode_runtime_owner=owner.engine_decode_runtime_owner,
|
||||
|
||||
@ -129,7 +129,7 @@ class EnginePolicyArbiterController:
|
||||
self.state.total_ticks += 1
|
||||
if stage == "idle":
|
||||
self.state.total_idle_ticks += 1
|
||||
elif stage == "prepare":
|
||||
elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}:
|
||||
self.state.total_prepare_dispatches += 1
|
||||
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
|
||||
elif stage == "finalize":
|
||||
@ -282,7 +282,11 @@ class EnginePolicyArbiterController:
|
||||
request_registry = self.snapshot_request_registry()
|
||||
worker_state = self.get_worker_state()
|
||||
policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
|
||||
prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0))
|
||||
prepare_state = self.snapshot_prepare_state()
|
||||
prepare_waiting = int(prepare_state.get("waiting_count", 0))
|
||||
prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0))
|
||||
prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0))
|
||||
prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0))
|
||||
finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
|
||||
decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
|
||||
decode_runtime_state = self.snapshot_decode_runtime_state()
|
||||
@ -291,6 +295,9 @@ class EnginePolicyArbiterController:
|
||||
worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
|
||||
worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
|
||||
prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
|
||||
prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio"))
|
||||
prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text"))
|
||||
prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec"))
|
||||
finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
|
||||
decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
|
||||
decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
|
||||
@ -316,14 +323,31 @@ class EnginePolicyArbiterController:
|
||||
and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
|
||||
):
|
||||
return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
|
||||
if (
|
||||
finalize_waiting > 0
|
||||
and prepare_ref_spec_waiting > 0
|
||||
and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0)
|
||||
):
|
||||
return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state
|
||||
if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
|
||||
return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
|
||||
return "finalize", "finalize_aging", policy_snapshot, worker_state
|
||||
if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
|
||||
return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
|
||||
return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0:
|
||||
return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
|
||||
return "prepare", "prepare_aging", policy_snapshot, worker_state
|
||||
if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6):
|
||||
return "prepare_text", "prepare_aging", policy_snapshot, worker_state
|
||||
if (
|
||||
prepare_ref_spec_waiting > 0
|
||||
and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6)
|
||||
):
|
||||
return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state
|
||||
return "prepare_audio", "prepare_aging", policy_snapshot, worker_state
|
||||
if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
|
||||
return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
|
||||
if decode_waiting > 0 and policy_allowed:
|
||||
@ -331,5 +355,9 @@ class EnginePolicyArbiterController:
|
||||
if finalize_waiting > 0:
|
||||
return "finalize", "finalize_fallback", policy_snapshot, worker_state
|
||||
if prepare_waiting > 0:
|
||||
return "prepare", "prepare_fallback", policy_snapshot, worker_state
|
||||
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
|
||||
return "prepare_text", "prepare_fallback", policy_snapshot, worker_state
|
||||
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0:
|
||||
return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state
|
||||
return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state
|
||||
return "idle", "no_pending_work", policy_snapshot, worker_state
|
||||
|
||||
@ -324,8 +324,24 @@ class EngineGpuPrepareTask:
|
||||
done_future: asyncio.Future | None
|
||||
engine_request_id: str | None
|
||||
enqueue_time: float
|
||||
queue_wait_ms: float = 0.0
|
||||
phase: str = "audio"
|
||||
audio_enqueue_time: float = 0.0
|
||||
audio_start_time: float = 0.0
|
||||
audio_end_time: float = 0.0
|
||||
text_enqueue_time: float = 0.0
|
||||
text_start_time: float = 0.0
|
||||
text_end_time: float = 0.0
|
||||
ref_spec_enqueue_time: float = 0.0
|
||||
ref_spec_start_time: float = 0.0
|
||||
ref_spec_end_time: float = 0.0
|
||||
audio_queue_wait_ms: float = 0.0
|
||||
text_queue_wait_ms: float = 0.0
|
||||
ref_spec_queue_wait_ms: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
phase_one: Dict[str, Any] | None = None
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None
|
||||
state_result: T2SRequestState | None = None
|
||||
cancelled: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
|
||||
@ -14,6 +14,8 @@ class EngineStageOrchestrator:
|
||||
executor: EngineStageExecutor,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_text_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
@ -22,6 +24,8 @@ class EngineStageOrchestrator:
|
||||
self.executor = executor
|
||||
self.scheduler_worker = scheduler_worker
|
||||
self.prepare_queue_owner = prepare_queue_owner
|
||||
self.prepare_text_queue_owner = prepare_text_queue_owner
|
||||
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
|
||||
self.finalize_queue_owner = finalize_queue_owner
|
||||
self.dispatch_queue_owner = dispatch_queue_owner
|
||||
self.decode_runtime_owner = decode_runtime_owner
|
||||
@ -45,7 +49,17 @@ class EngineStageOrchestrator:
|
||||
|
||||
def peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
if queue_name == "prepare":
|
||||
return max(
|
||||
self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"),
|
||||
self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"),
|
||||
self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"),
|
||||
)
|
||||
if queue_name == "prepare_audio":
|
||||
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "prepare_text":
|
||||
return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "prepare_ref_spec":
|
||||
return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "finalize":
|
||||
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
|
||||
if queue_name == "decode_runtime_pending":
|
||||
@ -62,6 +76,10 @@ class EngineStageOrchestrator:
|
||||
return True
|
||||
if self.prepare_queue_owner.has_items():
|
||||
return True
|
||||
if self.prepare_text_queue_owner.has_items():
|
||||
return True
|
||||
if self.prepare_ref_spec_queue_owner.has_items():
|
||||
return True
|
||||
if self.finalize_queue_owner.has_items():
|
||||
return True
|
||||
return self.dispatch_queue_owner.has_items()
|
||||
@ -79,6 +97,12 @@ class EngineStageOrchestrator:
|
||||
executed = False
|
||||
if stage == "prepare":
|
||||
executed = self.executor.run_engine_prepare_once()
|
||||
elif stage == "prepare_audio":
|
||||
executed = self.executor.run_engine_prepare_audio_once()
|
||||
elif stage == "prepare_text":
|
||||
executed = self.executor.run_engine_prepare_text_once()
|
||||
elif stage == "prepare_ref_spec":
|
||||
executed = self.executor.run_engine_prepare_ref_spec_once()
|
||||
elif stage == "finalize":
|
||||
executed = self.executor.run_engine_finalize_once()
|
||||
elif stage == "decode_dispatch":
|
||||
|
||||
@ -24,6 +24,8 @@ class EngineStageCoordinator:
|
||||
tts: TTS,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_text_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
@ -45,6 +47,8 @@ class EngineStageCoordinator:
|
||||
tts=tts,
|
||||
scheduler_worker=scheduler_worker,
|
||||
prepare_queue_owner=prepare_queue_owner,
|
||||
prepare_text_queue_owner=prepare_text_queue_owner,
|
||||
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
|
||||
finalize_queue_owner=finalize_queue_owner,
|
||||
dispatch_queue_owner=dispatch_queue_owner,
|
||||
decode_runtime_owner=decode_runtime_owner,
|
||||
@ -66,6 +70,8 @@ class EngineStageCoordinator:
|
||||
executor=self.executor,
|
||||
scheduler_worker=scheduler_worker,
|
||||
prepare_queue_owner=prepare_queue_owner,
|
||||
prepare_text_queue_owner=prepare_text_queue_owner,
|
||||
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
|
||||
finalize_queue_owner=finalize_queue_owner,
|
||||
dispatch_queue_owner=dispatch_queue_owner,
|
||||
decode_runtime_owner=decode_runtime_owner,
|
||||
@ -144,6 +150,15 @@ class EngineStageCoordinator:
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_once()
|
||||
|
||||
def run_engine_prepare_audio_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_audio_once()
|
||||
|
||||
def run_engine_prepare_text_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_text_once()
|
||||
|
||||
def run_engine_prepare_ref_spec_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_ref_spec_once()
|
||||
|
||||
def run_engine_finalize_once(self) -> bool:
|
||||
return self.executor.run_engine_finalize_once()
|
||||
|
||||
|
||||
@ -24,6 +24,10 @@ class EngineDispatchStageMixin:
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
|
||||
error = RuntimeError("ref_spec async stage failed before dispatch")
|
||||
self.fail_request_state(engine_request_id or state.request_id, str(error))
|
||||
raise error
|
||||
task = EngineDispatchTask(
|
||||
request_id=state.request_id,
|
||||
state=state,
|
||||
|
||||
@ -31,6 +31,8 @@ class EngineStageExecutor(
|
||||
tts: TTS,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_text_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
@ -51,6 +53,8 @@ class EngineStageExecutor(
|
||||
self.tts = tts
|
||||
self.scheduler_worker = scheduler_worker
|
||||
self.prepare_queue_owner = prepare_queue_owner
|
||||
self.prepare_text_queue_owner = prepare_text_queue_owner
|
||||
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
|
||||
self.finalize_queue_owner = finalize_queue_owner
|
||||
self.dispatch_queue_owner = dispatch_queue_owner
|
||||
self.decode_runtime_owner = decode_runtime_owner
|
||||
|
||||
@ -39,10 +39,37 @@ class EngineFinalizeStageMixin:
|
||||
tasks = self.take_engine_finalize_batch_nonblocking()
|
||||
if not tasks:
|
||||
return False
|
||||
self.scheduler_worker.begin_finalize_execution(len(tasks))
|
||||
ready_tasks: List[SchedulerFinalizeTask] = []
|
||||
failed_tasks: List[SchedulerFinalizeTask] = []
|
||||
deferred_tasks: List[SchedulerFinalizeTask] = []
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
|
||||
failed_tasks.append(task)
|
||||
continue
|
||||
if job.state.refer_spec is None:
|
||||
deferred_tasks.append(task)
|
||||
self.merge_request_state_profile(
|
||||
job.engine_request_id or job.request_id,
|
||||
{
|
||||
"engine_finalize_ref_spec_blocked": 1.0,
|
||||
},
|
||||
)
|
||||
continue
|
||||
ready_tasks.append(task)
|
||||
if deferred_tasks:
|
||||
self.finalize_queue_owner.enqueue_many(deferred_tasks)
|
||||
if failed_tasks:
|
||||
self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed")
|
||||
if not ready_tasks:
|
||||
self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True)
|
||||
return False
|
||||
self.scheduler_worker.begin_finalize_execution(len(ready_tasks))
|
||||
try:
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
for task in tasks:
|
||||
for task in ready_tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
@ -50,7 +77,7 @@ class EngineFinalizeStageMixin:
|
||||
if not jobs_and_items:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
for task in tasks:
|
||||
for task in ready_tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
|
||||
@ -69,8 +96,8 @@ class EngineFinalizeStageMixin:
|
||||
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
|
||||
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
except Exception as exc:
|
||||
self.fail_engine_jobs([task.request_id for task in tasks], str(exc))
|
||||
self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc))
|
||||
finally:
|
||||
self.scheduler_worker.end_finalize_execution(len(tasks))
|
||||
self.finalize_queue_owner.mark_completed(len(tasks), notify=True)
|
||||
self.scheduler_worker.end_finalize_execution(len(ready_tasks))
|
||||
self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True)
|
||||
return True
|
||||
|
||||
@ -10,6 +10,13 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare
|
||||
|
||||
|
||||
class EnginePrepareStageMixin:
|
||||
def _prepare_waiting_total(self) -> int:
|
||||
return (
|
||||
int(self.prepare_queue_owner.waiting_count())
|
||||
+ int(self.prepare_text_queue_owner.waiting_count())
|
||||
+ int(self.prepare_ref_spec_queue_owner.waiting_count())
|
||||
)
|
||||
|
||||
async def _wait_prepare_queue_admission(self) -> float:
|
||||
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
|
||||
if soft_max <= 0:
|
||||
@ -19,7 +26,7 @@ class EnginePrepareStageMixin:
|
||||
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
|
||||
)
|
||||
wait_start = time.perf_counter()
|
||||
while self.prepare_queue_owner.waiting_count() >= soft_max:
|
||||
while self._prepare_waiting_total() >= soft_max:
|
||||
await asyncio.sleep(poll_s)
|
||||
return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
|
||||
|
||||
@ -53,44 +60,247 @@ class EnginePrepareStageMixin:
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or spec.request_id,
|
||||
enqueue_time=time.perf_counter(),
|
||||
phase="audio",
|
||||
audio_enqueue_time=time.perf_counter(),
|
||||
admission_wait_ms=float(prepare_queue_admission_wait_ms),
|
||||
)
|
||||
self.prepare_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
return await done_future
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
def _should_chain_prepare_text_after_audio(self) -> bool:
|
||||
if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items():
|
||||
return False
|
||||
decode_runtime_state = self.snapshot_engine_decode_runtime_state()
|
||||
if bool(decode_runtime_state.get("has_work", False)):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None:
|
||||
if task.state_result is None or task.ref_spec_result is None:
|
||||
return
|
||||
self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms),
|
||||
"ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)),
|
||||
"ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)),
|
||||
"ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)),
|
||||
"ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)),
|
||||
"ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)),
|
||||
"ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)),
|
||||
"ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)),
|
||||
},
|
||||
)
|
||||
|
||||
def _mark_ref_spec_async_failed(
|
||||
self,
|
||||
task: EngineGpuPrepareTask,
|
||||
error: Exception,
|
||||
*,
|
||||
queue_wait_ms: float,
|
||||
) -> None:
|
||||
task.error = str(error)
|
||||
task.cancelled = True
|
||||
if task.state_result is not None:
|
||||
task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"ref_spec_async_failed": 1.0,
|
||||
"engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms),
|
||||
},
|
||||
)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(error))
|
||||
self.fail_engine_jobs([task.request_id], str(error))
|
||||
self.notify_arbiter()
|
||||
|
||||
def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool:
|
||||
tasks = self.prepare_queue_owner.pop_left_many(batch_max_items)
|
||||
if not tasks:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
|
||||
batch_results = asyncio.run(
|
||||
self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks])
|
||||
)
|
||||
for task in tasks:
|
||||
task.audio_start_time = float(now)
|
||||
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks]))
|
||||
completed_count = 0
|
||||
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
|
||||
task.audio_end_time = time.perf_counter()
|
||||
if isinstance(result, Exception):
|
||||
task.error = str(result)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
|
||||
self._notify_prepare_error(task, result)
|
||||
completed_count += 1
|
||||
continue
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = result
|
||||
state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms)
|
||||
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
|
||||
task.audio_queue_wait_ms = float(queue_wait_ms)
|
||||
task.phase_one = result
|
||||
task.phase = "text"
|
||||
task.enqueue_time = time.perf_counter()
|
||||
task.text_enqueue_time = float(task.enqueue_time)
|
||||
task.ref_spec_enqueue_time = float(task.enqueue_time)
|
||||
self.prepare_text_queue_owner.enqueue(task)
|
||||
self.prepare_ref_spec_queue_owner.enqueue(task)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_prepare_audio_queue_wait_ms": float(queue_wait_ms),
|
||||
"engine_prepare_audio_batch_size": float(len(tasks)),
|
||||
"engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
|
||||
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
|
||||
"engine_prepare_audio_start_ts": float(task.audio_start_time),
|
||||
"engine_prepare_audio_end_ts": float(task.audio_end_time),
|
||||
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
|
||||
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
|
||||
},
|
||||
)
|
||||
completed_count += 1
|
||||
self.prepare_queue_owner.mark_completed(completed_count)
|
||||
if completed_count > 0 and self._should_chain_prepare_text_after_audio():
|
||||
self._run_engine_prepare_text_once(min(batch_max_items, completed_count))
|
||||
return True
|
||||
if completed_count > 0:
|
||||
self.notify_arbiter()
|
||||
return True
|
||||
|
||||
def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool:
|
||||
tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items)
|
||||
if not tasks:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
|
||||
for task in tasks:
|
||||
task.text_start_time = float(now)
|
||||
items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None]
|
||||
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items))
|
||||
completed_count = 0
|
||||
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
|
||||
task.text_end_time = time.perf_counter()
|
||||
if isinstance(result, Exception):
|
||||
task.error = str(result)
|
||||
task.cancelled = True
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
|
||||
self._notify_prepare_error(task, result)
|
||||
completed_count += 1
|
||||
continue
|
||||
task.text_queue_wait_ms = float(queue_wait_ms)
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases(
|
||||
task.cpu_stage,
|
||||
task.phase_one or {},
|
||||
result,
|
||||
extra_profile={
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
|
||||
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
|
||||
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
|
||||
"engine_prepare_audio_batch_size": float(len(tasks)),
|
||||
"engine_prepare_text_batch_size": float(len(tasks)),
|
||||
"engine_prepare_audio_phase_mode": 2.0,
|
||||
"engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)),
|
||||
"engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
|
||||
"engine_prepare_text_phase_batch_size": float(len(tasks)),
|
||||
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
|
||||
"engine_prepare_audio_start_ts": float(task.audio_start_time),
|
||||
"engine_prepare_audio_end_ts": float(task.audio_end_time),
|
||||
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
|
||||
"engine_prepare_text_start_ts": float(task.text_start_time),
|
||||
"engine_prepare_text_end_ts": float(task.text_end_time),
|
||||
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
|
||||
},
|
||||
)
|
||||
task.state_result = state
|
||||
self._maybe_apply_ref_spec_to_state(task)
|
||||
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms),
|
||||
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
|
||||
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
|
||||
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
|
||||
"engine_gpu_prepare_batch_size": float(len(tasks)),
|
||||
},
|
||||
)
|
||||
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
|
||||
completed_count += 1
|
||||
self.prepare_queue_owner.mark_completed(completed_count)
|
||||
self.prepare_text_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
|
||||
def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool:
|
||||
tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items)
|
||||
if not tasks:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
runnable_tasks: list[EngineGpuPrepareTask] = []
|
||||
queue_wait_ms_list: list[float] = []
|
||||
completed_count = 0
|
||||
for task in tasks:
|
||||
if task.cancelled or task.phase_one is None:
|
||||
completed_count += 1
|
||||
continue
|
||||
task.ref_spec_start_time = float(now)
|
||||
runnable_tasks.append(task)
|
||||
queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0))
|
||||
if not runnable_tasks:
|
||||
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
batch_results = asyncio.run(
|
||||
self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks])
|
||||
)
|
||||
for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results):
|
||||
task.ref_spec_end_time = time.perf_counter()
|
||||
task.ref_spec_queue_wait_ms = float(queue_wait_ms)
|
||||
if isinstance(result, Exception):
|
||||
self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms))
|
||||
completed_count += 1
|
||||
continue
|
||||
task.ref_spec_result = result
|
||||
self._maybe_apply_ref_spec_to_state(task)
|
||||
if task.state_result is not None:
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time)
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time)
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time)
|
||||
completed_count += 1
|
||||
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1))
|
||||
audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if self.prepare_text_queue_owner.has_items() and (
|
||||
not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms
|
||||
):
|
||||
return self._run_engine_prepare_text_once(batch_max_items)
|
||||
if self.prepare_queue_owner.has_items():
|
||||
return self._run_engine_prepare_audio_once(batch_max_items)
|
||||
if self.prepare_ref_spec_queue_owner.has_items():
|
||||
return self._run_engine_prepare_ref_spec_once(batch_max_items)
|
||||
if self.prepare_text_queue_owner.has_items():
|
||||
return self._run_engine_prepare_text_once(batch_max_items)
|
||||
if self.prepare_ref_spec_queue_owner.has_items():
|
||||
return self._run_engine_prepare_ref_spec_once(batch_max_items)
|
||||
return False
|
||||
|
||||
def run_engine_prepare_audio_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
|
||||
def run_engine_prepare_text_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
|
||||
def run_engine_prepare_ref_spec_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
|
||||
@ -151,7 +151,9 @@ class WorkerFinalizeExecutor:
|
||||
|
||||
@staticmethod
|
||||
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
|
||||
refer_specs = [job.state.refer_spec]
|
||||
refer_specs = []
|
||||
if job.state.refer_spec is not None:
|
||||
refer_specs.append(job.state.refer_spec)
|
||||
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
|
||||
return refer_specs
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
|
||||
@ -81,11 +81,60 @@ class WorkerPrepareExecutor:
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[tuple[T2SRequestState, float, float] | Exception]:
|
||||
try:
|
||||
return list(
|
||||
await asyncio.gather(
|
||||
*[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages],
|
||||
return_exceptions=True,
|
||||
)
|
||||
return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_audio_phases_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_text_phases_async(
|
||||
self,
|
||||
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_gpu_text_phases_async(items)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
def build_gpu_prepare_result_from_phases(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
phase_two: Dict[str, Any],
|
||||
extra_profile: Dict[str, float] | None = None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
try:
|
||||
return self.coordinator.build_gpu_prepare_result_from_phases(
|
||||
cpu_stage,
|
||||
phase_one,
|
||||
phase_two,
|
||||
extra_profile=extra_profile,
|
||||
)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_ref_spec_stages_async(
|
||||
self,
|
||||
phase_ones: List[Dict[str, Any]],
|
||||
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_ref_spec_stages_async(phase_ones)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
def apply_ref_spec_result_to_state(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
|
||||
) -> None:
|
||||
try:
|
||||
self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
@ -267,3 +267,42 @@ class WorkerSubmitLifecycleMixin:
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[tuple[T2SRequestState, float, float] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
|
||||
|
||||
async def prepare_gpu_audio_phases_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages)
|
||||
|
||||
async def prepare_gpu_text_phases_async(
|
||||
self,
|
||||
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_text_phases_async(items)
|
||||
|
||||
def build_gpu_prepare_result_from_phases(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
phase_two: Dict[str, Any],
|
||||
extra_profile: Dict[str, float] | None = None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return self.prepare_executor.build_gpu_prepare_result_from_phases(
|
||||
cpu_stage,
|
||||
phase_one,
|
||||
phase_two,
|
||||
extra_profile=extra_profile,
|
||||
)
|
||||
|
||||
async def prepare_ref_spec_stages_async(
|
||||
self,
|
||||
phase_ones: List[Dict[str, Any]],
|
||||
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
|
||||
return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones)
|
||||
|
||||
def apply_ref_spec_result_to_state(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
|
||||
) -> None:
|
||||
self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result)
|
||||
|
||||
@ -244,6 +244,16 @@ class G2PWRuntimeWrapper:
|
||||
)
|
||||
self.batch_worker.start()
|
||||
|
||||
def _sync_runtime_env_overrides(self) -> None:
|
||||
os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0"
|
||||
os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0"
|
||||
os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0"
|
||||
os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit))
|
||||
os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit))
|
||||
os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0"
|
||||
os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0"
|
||||
os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32")
|
||||
|
||||
def _destroy_handle(self) -> None:
|
||||
if self.handle:
|
||||
self.lib.g2pw_runtime_destroy(self.handle)
|
||||
@ -268,6 +278,7 @@ class G2PWRuntimeWrapper:
|
||||
return "" if not message else message.decode("utf-8", errors="replace")
|
||||
|
||||
def _create_handle(self, batch_size: int, seq_len: int) -> None:
|
||||
self._sync_runtime_env_overrides()
|
||||
new_handle = self.lib.g2pw_runtime_create(
|
||||
str(self.manifest_path).encode("utf-8"),
|
||||
str(self.weights_path).encode("utf-8"),
|
||||
@ -518,6 +529,10 @@ class G2PWRuntimeWrapper:
|
||||
return {
|
||||
"shard_index": int(self.shard_index),
|
||||
"enabled": bool(self.batch_enabled),
|
||||
"enable_cuda_graph": bool(self.enable_cuda_graph),
|
||||
"enable_profiling": bool(self.enable_profiling),
|
||||
"full_graph_cache_limit": int(self.full_graph_cache_limit),
|
||||
"tail_graph_cache_limit": int(self.tail_graph_cache_limit),
|
||||
"window_ms": float(self.batch_window_s * 1000.0),
|
||||
"max_requests": int(self.batch_max_requests),
|
||||
"max_rows": int(self.batch_max_rows),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user