From 5cf68a91d3b24299bf05edeec3907e46c4c6c0b5 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Thu, 12 Mar 2026 23:03:33 +0800 Subject: [PATCH] Add g2pw submodule and enhance TTS processing with AsyncStageGate Introduce a new submodule for g2pw and implement AsyncStageGate in PrepareCoordinator to manage concurrent task inflight limits. Update PrepareTextCpuWorker and PrepareRefSemanticBatchWorker to support asynchronous task submission and completion notifications. Enhance profiling capabilities in TTS to track g2pw processing times, improving overall performance and maintainability of the TTS system. --- .gitmodules | 3 + GPT_SoVITS/TTS_infer_pack/TTS.py | 172 ++++++- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 9 + .../TTS_infer_pack/prepare_coordinator.py | 423 ++++++++++++++---- .../prepare_ref_semantic_batch_worker.py | 42 +- .../TTS_infer_pack/prepare_text_cpu_worker.py | 215 +++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 122 ++++- .../unified_engine_api_profile.py | 1 + .../unified_engine_component_runtime.py | 11 + .../unified_engine_stage_prepare.py | 59 ++- .../unified_engine_worker_prepare.py | 20 + .../unified_engine_worker_submit.py | 9 + third_party/g2pw-cu | 1 + 13 files changed, 965 insertions(+), 122 deletions(-) create mode 100644 .gitmodules create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py create mode 160000 third_party/g2pw-cu diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..570e9d7b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/g2pw-cu"] + path = third_party/g2pw-cu + url = https://github.com/baicai-1145/g2pw-cu.git diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bd811d8a..92c829a1 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,4 +1,5 @@ import gc +import asyncio import concurrent.futures import math import os @@ -42,6 +43,7 @@ from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( PrepareRefSemanticBatchWorker, prepare_prompt_semantic_wav16k, ) +from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker from sv import SV resample_transform_dict = {} @@ -454,18 +456,12 @@ class TTS: self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None self.prepare_text_cpu_workers = max( 0, int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), ) - self.prepare_text_cpu_executor = ( - concurrent.futures.ThreadPoolExecutor( - max_workers=self.prepare_text_cpu_workers, - thread_name_prefix="prepare-text-cpu", - ) - if self.prepare_text_cpu_workers > 0 - else None - ) + self.prepare_text_cpu_executor = None self._init_models() self.refresh_runtime_components() @@ -488,6 +484,7 @@ class TTS: def refresh_runtime_components(self): self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": self.prepare_bert_batch_worker = PrepareBertBatchWorker( bert_model=self.bert_model, @@ -535,6 +532,92 @@ class TTS: bert_stage_limiter=self.prepare_bert_stage_limiter, bert_batch_worker=self.prepare_bert_batch_worker, ) + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_worker = PrepareTextCpuWorker( + process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments( + text, + language, + self.configs.version, + ), + worker_count=self.prepare_text_cpu_workers, + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")), + admission_controller=self._build_text_cpu_admission_state, + ) + + @staticmethod + def _safe_queue_qsize(executor) -> int | None: + if executor is None: + return None + queue = getattr(executor, "_work_queue", None) + if queue is None or not hasattr(queue, "qsize"): + return None + try: + return int(queue.qsize()) + except Exception: + return None + + def snapshot_prepare_runtime_components(self) -> dict: + return { + "text_cpu": { + "workers": int(self.prepare_text_cpu_workers), + "queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor), + "enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None), + "worker": ( + None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot()) + ), + "admission": self._build_text_cpu_admission_state(), + }, + "bert": { + "stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()), + "batch_worker": ( + None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_bert_batch_worker is not None), + }, + "ref_semantic": { + "stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()), + "batch_worker": ( + None + if self.prepare_ref_semantic_batch_worker is None + else dict(self.prepare_ref_semantic_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None), + }, + "text_preprocessor": ( + None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot() + ), + } + + def _build_text_cpu_admission_state(self) -> dict: + bert_pending_soft_max = max( + 0, + int( + os.environ.get( + "GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX", + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"), + ) + ), + ) + if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0: + return { + "blocked": False, + "reason": "", + "bert_pending": 0, + "bert_active_batch_size": 0, + "bert_pending_soft_max": int(bert_pending_soft_max), + } + bert_state = dict(self.prepare_bert_batch_worker.snapshot()) + bert_pending = int(bert_state.get("pending", 0)) + bert_active_batch_size = int(bert_state.get("active_batch_size", 0)) + blocked = bert_pending >= bert_pending_soft_max + return { + "blocked": bool(blocked), + "reason": ("bert_pending" if blocked else ""), + "bert_pending": int(bert_pending), + "bert_active_batch_size": int(bert_active_batch_size), + "bert_pending_soft_max": int(bert_pending_soft_max), + } def _init_models( self, @@ -1040,6 +1123,79 @@ class TTS: }, } + async def extract_ref_audio_bundle_async(self, ref_audio_path: str): + if self.prepare_ref_semantic_batch_worker is None: + return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path) + + load_start = time.perf_counter() + raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + + prompt_semantic_task = asyncio.create_task( + self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + + def _build_ref_spec_profile(): + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + return refer_spec, { + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": float(ref_spec_ms), + "audio_stage_slots": float(ref_spec_limiter_stats["slots"]), + "audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]), + } + + ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile)) + (prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_profile.get("ref_spec_wait_ms", 0.0) + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_profile.get("audio_stage_slots", 0.0)), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": float(load_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float(audio_stage_slots), + "audio_stage_inflight_peak": float(audio_stage_inflight_peak), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float(prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)), + "bundle_total_ms": float(load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)), + }, + } + def extract_text_features(self, text: str, language: str, profile: dict | None = None): return self.text_preprocessor.segment_and_extract_feature_for_text( text, language, self.configs.version, profile=profile diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 6bee49be..01a8ea4d 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -118,6 +118,15 @@ class TextPreprocessor: self.bert_stage_limiter = bert_stage_limiter self.bert_batch_worker = bert_batch_worker + def snapshot(self) -> Dict[str, object]: + return { + "device": str(self.device), + "bert_stage_limiter": ( + None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot()) + ), + "bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()), + } + def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") text = self.replace_consecutive_punctuation(text) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 06a5e1b8..71134268 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -24,6 +24,7 @@ class ProfiledResult: submit_at: float started_at: float finished_at: float + profile: Dict[str, float] | None = None @property def queue_ms(self) -> float: @@ -48,6 +49,52 @@ class PreparedCpuStage: target_cpu_profiled: ProfiledResult +class AsyncStageGate: + def __init__(self, max_inflight: int, poll_ms: int = 1): + self.max_inflight = max(0, int(max_inflight)) + self.lock = threading.Lock() + self.poll_s = max(0.0005, float(max(1, int(poll_ms))) / 1000.0) + self.inflight = 0 + self.peak_inflight = 0 + self.total_entered = 0 + self.total_wait_ms = 0.0 + self.wait_peak_ms = 0.0 + + async def acquire(self) -> Dict[str, float]: + wait_start = time.perf_counter() + while True: + with self.lock: + if self.max_inflight <= 0 or self.inflight < self.max_inflight: + self.inflight += 1 + self.total_entered += 1 + wait_ms = max(0.0, (time.perf_counter() - wait_start) * 1000.0) + self.total_wait_ms += float(wait_ms) + self.wait_peak_ms = max(self.wait_peak_ms, float(wait_ms)) + self.peak_inflight = max(self.peak_inflight, self.inflight) + return { + "wait_ms": float(wait_ms), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "max_inflight": float(self.max_inflight), + } + await asyncio.sleep(self.poll_s) + + def release(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, float]: + with self.lock: + return { + "max_inflight": float(self.max_inflight), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "total_entered": float(self.total_entered), + "total_wait_ms": float(self.total_wait_ms), + "wait_peak_ms": float(self.wait_peak_ms), + } + + class PrepareCoordinator: def __init__(self, tts: Any): self.tts = tts @@ -59,7 +106,8 @@ class PrepareCoordinator: and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" ) self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) - self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None + gate_poll_ms = int(os.environ.get("GPTSOVITS_PREPARE_GATE_POLL_MS", "1")) + self._inflight_gate = AsyncStageGate(self.max_inflight, poll_ms=gate_poll_ms) self.text_feature_workers = 0 self.text_feature_executor = None if not self.use_async_text_feature_path: @@ -81,6 +129,29 @@ class PrepareCoordinator: max_workers=self.ref_audio_workers, thread_name_prefix="prepare-ref-audio", ) + text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0)) + text_feature_gate_default = max(0, int(self.text_feature_workers)) + ref_audio_gate_default = max(0, int(self.ref_audio_workers)) + self.text_cpu_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))), + poll_ms=gate_poll_ms, + ) + self.text_feature_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_audio_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_load_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_LOAD_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_spec_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_SPEC_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) def _mark_enter(self) -> Tuple[int, int]: with self.lock: @@ -94,15 +165,29 @@ class PrepareCoordinator: with self.lock: self.inflight = max(0, self.inflight - 1) - def snapshot(self) -> Dict[str, int]: + def snapshot(self) -> Dict[str, Any]: with self.lock: - return { + snapshot: Dict[str, Any] = { "inflight": int(self.inflight), "peak_inflight": int(self.peak_inflight), "max_inflight": int(self.max_inflight), "text_feature_workers": int(self.text_feature_workers), "ref_audio_workers": int(self.ref_audio_workers), } + runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None) + if callable(runtime_snapshot_fn): + try: + snapshot["prepare_runtime_state"] = runtime_snapshot_fn() + except Exception: + snapshot["prepare_runtime_state"] = None + snapshot["prepare_stage_gates"] = { + "text_cpu": self.text_cpu_gate.snapshot(), + "text_feature": self.text_feature_gate.snapshot(), + "ref_audio": self.ref_audio_gate.snapshot(), + "ref_load": self.ref_load_gate.snapshot(), + "ref_spec": self.ref_spec_gate.snapshot(), + } + return snapshot @staticmethod def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: @@ -119,6 +204,12 @@ class PrepareCoordinator: def _prepare_text_cpu(self, text: str, language: str): return self.tts.prepare_text_segments(text, language) + def _load_ref_audio_raw(self, ref_audio_path: str): + return self.tts._load_ref_audio_raw(ref_audio_path) + + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): + return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + @staticmethod def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: feature_dim = 1024 @@ -155,17 +246,54 @@ class PrepareCoordinator: return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + await self.text_cpu_gate.acquire() if text in [None, ""]: - submit_at = time.perf_counter() - return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + try: + submit_at = time.perf_counter() + return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + finally: + self.text_cpu_gate.release() + text_cpu_worker = getattr(self.tts, "prepare_text_cpu_worker", None) executor = getattr(self.tts, "prepare_text_cpu_executor", None) - if executor is None: - submit_at = time.perf_counter() - return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) - return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + try: + if text_cpu_worker is not None: + submit_at = time.perf_counter() + result, worker_profile = await text_cpu_worker.submit_async(text, language) + started_at = float( + submit_at + + ( + float(worker_profile.get("text_cpu_admission_wait_ms", 0.0)) + + float(worker_profile.get("text_cpu_queue_wait_ms", 0.0)) + ) + / 1000.0 + ) + finished_at = float(started_at + float(worker_profile.get("text_cpu_run_ms", 0.0)) / 1000.0) + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=finished_at, + profile=dict(worker_profile), + ) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + finally: + self.text_cpu_gate.release() async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: - return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms) + await self.text_feature_gate.acquire() + try: + return await self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + prepared_segments, + language, + cpu_run_ms, + ) + finally: + self.text_feature_gate.release() @staticmethod def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: @@ -199,16 +327,34 @@ class PrepareCoordinator: ) return prompt_profiled, target_profiled + await self.text_feature_gate.acquire() target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} submit_at = time.perf_counter() started_at = float(submit_at) - if prompt_is_empty: - target_result_raw = await self.tts.build_text_features_from_segments_async( - target_segments, - profile=target_profile, - ) - prompt_result = self._build_empty_text_features_like( - PreparedTextFeatures( + try: + if prompt_is_empty: + target_result_raw = await self.tts.build_text_features_from_segments_async( + target_segments, + profile=target_profile, + ) + prompt_result = self._build_empty_text_features_like( + PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + ) + finished_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + ) + target_result = PreparedTextFeatures( phones=target_result_raw[0], bert_features=target_result_raw[1], norm_text=target_result_raw[2], @@ -216,13 +362,37 @@ class PrepareCoordinator: total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), cpu_preprocess_ms=float(target_cpu_run_ms), ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > target_profiled.finished_at: + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, ) finished_at = time.perf_counter() - prompt_profiled = ProfiledResult( - result=prompt_result, - submit_at=float(submit_at), - started_at=float(submit_at), - finished_at=float(submit_at), + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), ) target_result = PreparedTextFeatures( phones=target_result_raw[0], @@ -232,79 +402,152 @@ class PrepareCoordinator: total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), cpu_preprocess_ms=float(target_cpu_run_ms), ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) target_profiled = ProfiledResult( result=target_result, submit_at=float(submit_at), started_at=started_at, finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), ) - if finished_at > target_profiled.finished_at: + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) target_result.profile["bert_total_ms"] = max( self._estimate_text_feature_run_ms(target_profile), (finished_at - submit_at) * 1000.0, ) else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) return prompt_profiled, target_profiled - - prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} - prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( - prompt_segments, - target_segments, - prompt_profile=prompt_profile, - target_profile=target_profile, - ) - finished_at = time.perf_counter() - - prompt_result = PreparedTextFeatures( - phones=prompt_result_raw[0], - bert_features=prompt_result_raw[1], - norm_text=prompt_result_raw[2], - profile=prompt_profile, - total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), - cpu_preprocess_ms=float(prompt_cpu_run_ms), - ) - target_result = PreparedTextFeatures( - phones=target_result_raw[0], - bert_features=target_result_raw[1], - norm_text=target_result_raw[2], - profile=target_profile, - total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), - cpu_preprocess_ms=float(target_cpu_run_ms), - ) - prompt_profiled = ProfiledResult( - result=prompt_result, - submit_at=float(submit_at), - started_at=started_at, - finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), - ) - target_profiled = ProfiledResult( - result=target_result, - submit_at=float(submit_at), - started_at=started_at, - finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), - ) - if finished_at > prompt_profiled.finished_at: - prompt_result.profile["bert_total_ms"] = max( - self._estimate_text_feature_run_ms(prompt_profile), - (finished_at - submit_at) * 1000.0, - ) - target_result.profile["bert_total_ms"] = max( - self._estimate_text_feature_run_ms(target_profile), - (finished_at - submit_at) * 1000.0, - ) - else: - prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) - target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) - return prompt_profiled, target_profiled + finally: + self.text_feature_gate.release() async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: - return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + submit_at = time.perf_counter() + started_at = float(submit_at) + + await self.ref_load_gate.acquire() + try: + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + finally: + self.ref_load_gate.release() + + raw_audio, raw_sr = load_profiled.result + prompt_semantic_task = asyncio.create_task( + self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + await self.ref_spec_gate.acquire() + try: + ref_spec_task = asyncio.create_task( + self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + ) + (prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + finally: + self.ref_spec_gate.release() + + refer_spec = ref_spec_profiled.result + limiter_snapshot = ( + self.tts.prepare_ref_audio_stage_limiter.snapshot() + if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None + else {} + ) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = ( + float(load_profiled.queue_ms) + + float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(ref_spec_profiled.queue_ms) + ) + finished_at = time.perf_counter() + result = { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_queue_ms": float(load_profiled.queue_ms), + "audio_load_ms": float(load_profiled.run_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(limiter_snapshot.get("slots", 0.0)), + ) + ), + "audio_stage_inflight_peak": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(limiter_snapshot.get("peak_inflight", 0.0)), + ) + ), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_profiled.queue_ms), + "ref_spec_ms": float(ref_spec_profiled.run_ms), + "bundle_total_ms": float( + load_profiled.queue_ms + + load_profiled.run_ms + + prompt_semantic_ms + + ref_spec_profiled.queue_ms + + ref_spec_profiled.run_ms + ), + }, + } + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(finished_at), + ) + + await self.ref_audio_gate.acquire() + try: + if hasattr(self.tts, "extract_ref_audio_bundle_async"): + submit_at = time.perf_counter() + started_at = time.perf_counter() + result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + finally: + self.ref_audio_gate.release() def _release_split_stage_slot(self) -> None: self._mark_leave() - if self._inflight_semaphore is not None: - self._inflight_semaphore.release() + self._inflight_gate.release() async def prepare_cpu_stage_profiled_async( self, @@ -312,9 +555,11 @@ class PrepareCoordinator: prepare_submit_at: float, ) -> PreparedCpuStage: admission_start = time.perf_counter() - if self._inflight_semaphore is not None: - await self._inflight_semaphore.acquire() - prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0) + admission_stats = await self._inflight_gate.acquire() + prepare_admission_wait_ms = max( + float(admission_stats.get("wait_ms", 0.0)), + (time.perf_counter() - admission_start) * 1000.0, + ) current_inflight, peak_inflight = self._mark_enter() prepare_start = time.perf_counter() prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) @@ -382,10 +627,28 @@ class PrepareCoordinator: "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), "text_feature_queue_ms": target_feature_profiled.queue_ms, "text_feature_run_ms": target_feature_profiled.run_ms, "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index 7a1f9a53..64ca133f 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -1,3 +1,4 @@ +import asyncio import threading import time import uuid @@ -51,6 +52,8 @@ class RefSemanticTask: task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None result_prompt_semantic: torch.Tensor | None = None error: Exception | None = None profile: Dict[str, float] = field(default_factory=dict) @@ -115,6 +118,41 @@ class PrepareRefSemanticBatchWorker: assert task.result_prompt_semantic is not None return task.result_prompt_semantic, dict(task.profile) + async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = RefSemanticTask( + raw_audio=raw_audio, + raw_sr=int(raw_sr), + done_loop=loop, + done_future=loop.create_future(), + ) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + return await task.done_future + + @staticmethod + def _resolve_done_future(task: RefSemanticTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_prompt_semantic is not None + task.done_future.set_result((task.result_prompt_semantic, dict(task.profile))) + + def _notify_task_done(self, task: RefSemanticTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + def snapshot(self) -> Dict[str, int]: with self.condition: return { @@ -247,7 +285,7 @@ class PrepareRefSemanticBatchWorker: for task in batch: if task.result_prompt_semantic is not None: task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) - task.done_event.set() + self._notify_task_done(task) def _run_loop(self) -> None: while True: @@ -257,6 +295,6 @@ class PrepareRefSemanticBatchWorker: except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc - task.done_event.set() + self._notify_task_done(task) finally: self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py new file mode 100644 index 00000000..b7c985f0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py @@ -0,0 +1,215 @@ +import asyncio +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, Tuple + + +@dataclass +class TextCpuTask: + text: str + language: str + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + backpressure_wait_ms: float = 0.0 + capacity_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 + done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None + result: Any = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareTextCpuWorker: + def __init__( + self, + process_fn: Callable[[str, str], Any], + worker_count: int, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None, + ) -> None: + self.process_fn = process_fn + self.worker_count = max(1, int(worker_count)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + self.admission_controller = admission_controller + + self.condition = threading.Condition() + self.pending_tasks: Deque[TextCpuTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.active_workers = 0 + self.active_workers_peak = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 + self.backpressure_wait_total_ms = 0.0 + self.backpressure_wait_peak_ms = 0.0 + self.capacity_wait_total_ms = 0.0 + self.capacity_wait_peak_ms = 0.0 + self.backpressure_blocked_total = 0 + + self.worker_threads = [ + threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True) + for index in range(self.worker_count) + ] + for thread in self.worker_threads: + thread.start() + + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks + + def _get_admission_state(self) -> Dict[str, float | int | bool]: + if self.admission_controller is None: + return {"blocked": False} + try: + state = dict(self.admission_controller() or {}) + except Exception: + return {"blocked": False} + state["blocked"] = bool(state.get("blocked", False)) + return state + + def _record_enqueue_locked( + self, + task: TextCpuTask, + *, + admission_wait_ms: float, + backpressure_wait_ms: float, + capacity_wait_ms: float, + ) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms)) + task.capacity_wait_ms = float(max(0.0, capacity_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + self.backpressure_wait_total_ms += task.backpressure_wait_ms + self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms) + self.capacity_wait_total_ms += task.capacity_wait_ms + self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms) + if task.backpressure_wait_ms > 0.0: + self.backpressure_blocked_total += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + async def _enqueue_task_async(self, task: TextCpuTask) -> None: + admission_started = time.perf_counter() + backpressure_wait_ms = 0.0 + capacity_wait_ms = 0.0 + while True: + loop_start = time.perf_counter() + admission_state = self._get_admission_state() + blocked = bool(admission_state.get("blocked", False)) + with self.condition: + if not blocked and self._can_enqueue_locked(): + self._record_enqueue_locked( + task, + admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0, + backpressure_wait_ms=backpressure_wait_ms, + capacity_wait_ms=capacity_wait_ms, + ) + return + await asyncio.sleep(self.admission_poll_s) + waited_ms = (time.perf_counter() - loop_start) * 1000.0 + if blocked: + backpressure_wait_ms += waited_ms + else: + capacity_wait_ms += waited_ms + + def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + task = TextCpuTask(text=str(text), language=str(language)) + asyncio.run(self._enqueue_task_async(task)) + task.done_event.wait() + if task.error is not None: + raise task.error + return task.result, dict(task.profile) + + async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = TextCpuTask( + text=str(text), + language=str(language), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + + @staticmethod + def _resolve_done_future(task: TextCpuTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + task.done_future.set_result((task.result, dict(task.profile))) + + def _notify_task_done(self, task: TextCpuTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + + def snapshot(self) -> Dict[str, int | float]: + with self.condition: + return { + "worker_count": int(self.worker_count), + "pending": int(len(self.pending_tasks)), + "pending_peak": int(self.pending_peak), + "active_workers": int(self.active_workers), + "active_workers_peak": int(self.active_workers_peak), + "total_submitted": int(self.total_submitted), + "total_finished": int(self.total_finished), + "max_pending_tasks": int(self.max_pending_tasks), + "admission_wait_total_ms": float(self.admission_wait_total_ms), + "admission_wait_peak_ms": float(self.admission_wait_peak_ms), + "backpressure_wait_total_ms": float(self.backpressure_wait_total_ms), + "backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms), + "capacity_wait_total_ms": float(self.capacity_wait_total_ms), + "capacity_wait_peak_ms": float(self.capacity_wait_peak_ms), + "backpressure_blocked_total": int(self.backpressure_blocked_total), + } + + def _run_loop(self) -> None: + while True: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + task = self.pending_tasks.popleft() + self.active_workers += 1 + self.active_workers_peak = max(self.active_workers_peak, self.active_workers) + started_at = time.perf_counter() + try: + task.result = self.process_fn(task.text, task.language) + task.profile = { + "text_cpu_admission_wait_ms": float(task.admission_wait_ms), + "text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms), + "text_cpu_capacity_wait_ms": float(task.capacity_wait_ms), + "text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0), + "text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + finally: + with self.condition: + self.active_workers = max(0, self.active_workers - 1) + self.total_finished += 1 + self.condition.notify_all() + self._notify_task_done(task) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 78dbfc36..ed465f69 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -421,6 +421,55 @@ def _iter_contiguous_sampling_groups( return groups +def _uniform_sampling_group_key(active_batch: T2SActiveBatch) -> Optional[Tuple[int, float, float, float, bool]]: + if not active_batch.states: + return None + if active_batch.step_indices.numel() <= 0: + return None + first_step_index = int(active_batch.step_indices[0].item()) + if bool((active_batch.step_indices != first_step_index).any().item()): + return None + first_state = active_batch.states[0] + first_key = _sampling_group_key( + top_k=first_state.top_k, + top_p=first_state.top_p, + temperature=first_state.temperature, + repetition_penalty=first_state.repetition_penalty, + trim_eos=first_step_index < 11, + ) + for state in active_batch.states[1:]: + if ( + state.top_k != first_state.top_k + or state.top_p != first_state.top_p + or state.temperature != first_state.temperature + or state.repetition_penalty != first_state.repetition_penalty + ): + return None + return first_key + + +def _batched_sample_uniform( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_key: Tuple[int, float, float, float, bool], +) -> Tuple[torch.Tensor, torch.Tensor]: + top_k, top_p, temperature, repetition_penalty, trim_eos = sampling_key + sample_logits = logits[:, :-1] if trim_eos else logits + padded_histories, history_mask = _pad_token_sequences(histories) + probs = logits_to_probs( + logits=sample_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + sampled = multinomial_sample_one_no_sync(probs) + argmax_tokens = torch.argmax(sample_logits, dim=-1) + return sampled, argmax_tokens + + def _batched_sample_by_group( logits: torch.Tensor, histories: Sequence[torch.LongTensor], @@ -594,27 +643,59 @@ def _sample_per_request( keep_indices: List[int] = [] updated_sequences: List[torch.LongTensor] = [] - sampling_keys = [ - _sampling_group_key( - top_k=state.top_k, - top_p=state.top_p, - temperature=state.temperature, - repetition_penalty=state.repetition_penalty, - trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + uniform_sampling_key = _uniform_sampling_group_key(active_batch) + sampled_items: List[torch.Tensor] + argmax_tokens: List[int] + sampled_token_tensor: Optional[torch.Tensor] = None + argmax_token_tensor: Optional[torch.Tensor] = None + if uniform_sampling_key is not None: + sampled_tensor, argmax_tensor = _batched_sample_uniform( + logits=logits, + histories=active_batch.y_sequences, + sampling_key=uniform_sampling_key, + ) + sampled_token_tensor = sampled_tensor.view(-1) + argmax_token_tensor = argmax_tensor.view(-1) + if ( + all(state.early_stop_num == -1 for state in active_batch.states) + and int(active_batch.step_indices[0].item()) + 1 < max_steps + and not bool(sampled_token_tensor.eq(model.EOS).any().item()) + and not bool(argmax_token_tensor.eq(model.EOS).any().item()) + ): + return ( + [], + list(range(len(active_batch.states))), + [torch.cat([history, sampled_token_tensor[index : index + 1]], dim=0) for index, history in enumerate(active_batch.y_sequences)], + ) + sampled_items = [sampled_tensor[index : index + 1] for index in range(sampled_tensor.shape[0])] + argmax_tokens = [int(item) for item in argmax_tensor.tolist()] + else: + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + ) + for batch_index, state in enumerate(active_batch.states) + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, ) - for batch_index, state in enumerate(active_batch.states) - ] - sampled_items, argmax_tokens = _batched_sample_by_group( - logits=logits, - histories=active_batch.y_sequences, - sampling_keys=sampling_keys, - ) for batch_index, state in enumerate(active_batch.states): step_index = int(active_batch.step_indices[batch_index].item()) current_history = active_batch.y_sequences[batch_index] - sampled = sampled_items[batch_index] - sampled_token = int(sampled[0, 0].item()) - argmax_token = argmax_tokens[batch_index] + if sampled_token_tensor is not None and argmax_token_tensor is not None: + sampled = sampled_token_tensor[batch_index : batch_index + 1] + sampled_token = int(sampled_token_tensor[batch_index].item()) + argmax_token = int(argmax_token_tensor[batch_index].item()) + else: + sampled = sampled_items[batch_index] + sampled_token = int(sampled[0, 0].item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None @@ -690,6 +771,13 @@ def decode_one_step( finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) if len(keep_indices) == 0: return None, finished_items + if len(keep_indices) == len(active_batch.request_ids): + active_batch.y_sequences = updated_sequences + active_batch.step_indices = active_batch.step_indices + 1 + if not was_prefill and active_batch.kv_lens is not None: + active_batch.kv_lens = active_batch.kv_lens + 1 + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + return active_batch, finished_items device = logits.device keep_tensor = torch.LongTensor(keep_indices).to(device) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py index f950c68d..e31c5dfe 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py @@ -314,6 +314,7 @@ def build_scheduler_submit_headers( "X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), "X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))), "X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), "X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), "X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index 7f4e485f..15eedeca 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -44,6 +44,16 @@ class EngineTaskQueueOwner: return None return self.queue.popleft() + def pop_left_many(self, max_items: int) -> List[Any]: + limit = max(1, int(max_items)) + with self.condition: + if not self.queue: + return [] + selected: List[Any] = [] + while self.queue and len(selected) < limit: + selected.append(self.queue.popleft()) + return selected + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: if count <= 0: return @@ -315,6 +325,7 @@ class EngineGpuPrepareTask: engine_request_id: str | None enqueue_time: float queue_wait_ms: float = 0.0 + admission_wait_ms: float = 0.0 error: str | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py index bb3e8b06..b9095d2c 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import time from typing import Any @@ -9,6 +10,19 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare class EnginePrepareStageMixin: + async def _wait_prepare_queue_admission(self) -> float: + soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) + if soft_max <= 0: + return 0.0 + poll_s = max( + 0.0005, + float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, + ) + wait_start = time.perf_counter() + while self.prepare_queue_owner.waiting_count() >= soft_max: + await asyncio.sleep(poll_s) + return max(0.0, (time.perf_counter() - wait_start) * 1000.0) + async def prepare_state_via_engine_gpu_queue( self, *, @@ -16,12 +30,14 @@ class EnginePrepareStageMixin: prepare_submit_at: float, engine_request_id: str | None, ) -> tuple[T2SRequestState, float, float]: + prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission() cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) if engine_request_id not in [None, ""]: self.update_request_state( str(engine_request_id), EngineStatus.GPU_PREPARING, { + "engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms), "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), @@ -37,31 +53,44 @@ class EnginePrepareStageMixin: done_future=done_future, engine_request_id=engine_request_id or spec.request_id, enqueue_time=time.perf_counter(), + admission_wait_ms=float(prepare_queue_admission_wait_ms), ) self.prepare_queue_owner.enqueue(task) self.notify_arbiter() return await done_future def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + if not tasks: return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + batch_results = asyncio.run( + self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks]) + ) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + if isinstance(result, Exception): + task.error = str(result) + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + state, prepare_exec_started_at, prepare_exec_finished_at = result + state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms) state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) if task.engine_request_id not in [None, ""]: self.merge_request_state_profile( str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms), + "engine_gpu_prepare_batch_size": float(len(tasks)), + }, ) - self.prepare_queue_owner.mark_completed(1) self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True + completed_count += 1 + self.prepare_queue_owner.mark_completed(completed_count) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py index 28da24ee..8b3db1fa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import time from typing import Callable, Dict, List @@ -32,6 +33,11 @@ class WorkerPrepareExecutor: def get_max_inflight(self) -> int: return int(self.coordinator.snapshot().get("max_inflight", 0)) + def get_batch_policy(self) -> Dict[str, int]: + return { + "prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))), + } + def is_idle(self) -> bool: return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 @@ -69,3 +75,17 @@ class WorkerPrepareExecutor: return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) finally: self._notify_state_change() + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + try: + return list( + await asyncio.gather( + *[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py index f1910409..e498e9ea 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -32,6 +32,9 @@ class WorkerSubmitLifecycleMixin: def get_finalize_batch_policy(self) -> Dict[str, Any]: return dict(self.finalize_executor.get_batch_policy()) + def get_prepare_batch_policy(self) -> Dict[str, int]: + return dict(self.prepare_executor.get_batch_policy()) + def get_decode_runtime_counters(self) -> Dict[str, int]: with self.condition: return self.decode_runtime_tracker.get_counters() @@ -258,3 +261,9 @@ class WorkerSubmitLifecycleMixin: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) diff --git a/third_party/g2pw-cu b/third_party/g2pw-cu new file mode 160000 index 00000000..a53cf4ee --- /dev/null +++ b/third_party/g2pw-cu @@ -0,0 +1 @@ +Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98