diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..570e9d7b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/g2pw-cu"] + path = third_party/g2pw-cu + url = https://github.com/baicai-1145/g2pw-cu.git diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bd811d8a..92c829a1 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,4 +1,5 @@ import gc +import asyncio import concurrent.futures import math import os @@ -42,6 +43,7 @@ from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( PrepareRefSemanticBatchWorker, prepare_prompt_semantic_wav16k, ) +from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker from sv import SV resample_transform_dict = {} @@ -454,18 +456,12 @@ class TTS: self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None self.prepare_text_cpu_workers = max( 0, int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), ) - self.prepare_text_cpu_executor = ( - concurrent.futures.ThreadPoolExecutor( - max_workers=self.prepare_text_cpu_workers, - thread_name_prefix="prepare-text-cpu", - ) - if self.prepare_text_cpu_workers > 0 - else None - ) + self.prepare_text_cpu_executor = None self._init_models() self.refresh_runtime_components() @@ -488,6 +484,7 @@ class TTS: def refresh_runtime_components(self): self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": self.prepare_bert_batch_worker = PrepareBertBatchWorker( bert_model=self.bert_model, @@ -535,6 +532,92 @@ class TTS: bert_stage_limiter=self.prepare_bert_stage_limiter, bert_batch_worker=self.prepare_bert_batch_worker, ) + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_worker = PrepareTextCpuWorker( + process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments( + text, + language, + self.configs.version, + ), + worker_count=self.prepare_text_cpu_workers, + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")), + admission_controller=self._build_text_cpu_admission_state, + ) + + @staticmethod + def _safe_queue_qsize(executor) -> int | None: + if executor is None: + return None + queue = getattr(executor, "_work_queue", None) + if queue is None or not hasattr(queue, "qsize"): + return None + try: + return int(queue.qsize()) + except Exception: + return None + + def snapshot_prepare_runtime_components(self) -> dict: + return { + "text_cpu": { + "workers": int(self.prepare_text_cpu_workers), + "queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor), + "enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None), + "worker": ( + None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot()) + ), + "admission": self._build_text_cpu_admission_state(), + }, + "bert": { + "stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()), + "batch_worker": ( + None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_bert_batch_worker is not None), + }, + "ref_semantic": { + "stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()), + "batch_worker": ( + None + if self.prepare_ref_semantic_batch_worker is None + else dict(self.prepare_ref_semantic_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None), + }, + "text_preprocessor": ( + None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot() + ), + } + + def _build_text_cpu_admission_state(self) -> dict: + bert_pending_soft_max = max( + 0, + int( + os.environ.get( + "GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX", + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"), + ) + ), + ) + if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0: + return { + "blocked": False, + "reason": "", + "bert_pending": 0, + "bert_active_batch_size": 0, + "bert_pending_soft_max": int(bert_pending_soft_max), + } + bert_state = dict(self.prepare_bert_batch_worker.snapshot()) + bert_pending = int(bert_state.get("pending", 0)) + bert_active_batch_size = int(bert_state.get("active_batch_size", 0)) + blocked = bert_pending >= bert_pending_soft_max + return { + "blocked": bool(blocked), + "reason": ("bert_pending" if blocked else ""), + "bert_pending": int(bert_pending), + "bert_active_batch_size": int(bert_active_batch_size), + "bert_pending_soft_max": int(bert_pending_soft_max), + } def _init_models( self, @@ -1040,6 +1123,79 @@ class TTS: }, } + async def extract_ref_audio_bundle_async(self, ref_audio_path: str): + if self.prepare_ref_semantic_batch_worker is None: + return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path) + + load_start = time.perf_counter() + raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + + prompt_semantic_task = asyncio.create_task( + self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + + def _build_ref_spec_profile(): + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + return refer_spec, { + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": float(ref_spec_ms), + "audio_stage_slots": float(ref_spec_limiter_stats["slots"]), + "audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]), + } + + ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile)) + (prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_profile.get("ref_spec_wait_ms", 0.0) + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_profile.get("audio_stage_slots", 0.0)), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": float(load_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float(audio_stage_slots), + "audio_stage_inflight_peak": float(audio_stage_inflight_peak), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float(prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)), + "bundle_total_ms": float(load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)), + }, + } + def extract_text_features(self, text: str, language: str, profile: dict | None = None): return self.text_preprocessor.segment_and_extract_feature_for_text( text, language, self.configs.version, profile=profile diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 6bee49be..01a8ea4d 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -118,6 +118,15 @@ class TextPreprocessor: self.bert_stage_limiter = bert_stage_limiter self.bert_batch_worker = bert_batch_worker + def snapshot(self) -> Dict[str, object]: + return { + "device": str(self.device), + "bert_stage_limiter": ( + None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot()) + ), + "bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()), + } + def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") text = self.replace_consecutive_punctuation(text) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 06a5e1b8..71134268 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -24,6 +24,7 @@ class ProfiledResult: submit_at: float started_at: float finished_at: float + profile: Dict[str, float] | None = None @property def queue_ms(self) -> float: @@ -48,6 +49,52 @@ class PreparedCpuStage: target_cpu_profiled: ProfiledResult +class AsyncStageGate: + def __init__(self, max_inflight: int, poll_ms: int = 1): + self.max_inflight = max(0, int(max_inflight)) + self.lock = threading.Lock() + self.poll_s = max(0.0005, float(max(1, int(poll_ms))) / 1000.0) + self.inflight = 0 + self.peak_inflight = 0 + self.total_entered = 0 + self.total_wait_ms = 0.0 + self.wait_peak_ms = 0.0 + + async def acquire(self) -> Dict[str, float]: + wait_start = time.perf_counter() + while True: + with self.lock: + if self.max_inflight <= 0 or self.inflight < self.max_inflight: + self.inflight += 1 + self.total_entered += 1 + wait_ms = max(0.0, (time.perf_counter() - wait_start) * 1000.0) + self.total_wait_ms += float(wait_ms) + self.wait_peak_ms = max(self.wait_peak_ms, float(wait_ms)) + self.peak_inflight = max(self.peak_inflight, self.inflight) + return { + "wait_ms": float(wait_ms), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "max_inflight": float(self.max_inflight), + } + await asyncio.sleep(self.poll_s) + + def release(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, float]: + with self.lock: + return { + "max_inflight": float(self.max_inflight), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "total_entered": float(self.total_entered), + "total_wait_ms": float(self.total_wait_ms), + "wait_peak_ms": float(self.wait_peak_ms), + } + + class PrepareCoordinator: def __init__(self, tts: Any): self.tts = tts @@ -59,7 +106,8 @@ class PrepareCoordinator: and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" ) self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) - self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None + gate_poll_ms = int(os.environ.get("GPTSOVITS_PREPARE_GATE_POLL_MS", "1")) + self._inflight_gate = AsyncStageGate(self.max_inflight, poll_ms=gate_poll_ms) self.text_feature_workers = 0 self.text_feature_executor = None if not self.use_async_text_feature_path: @@ -81,6 +129,29 @@ class PrepareCoordinator: max_workers=self.ref_audio_workers, thread_name_prefix="prepare-ref-audio", ) + text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0)) + text_feature_gate_default = max(0, int(self.text_feature_workers)) + ref_audio_gate_default = max(0, int(self.ref_audio_workers)) + self.text_cpu_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))), + poll_ms=gate_poll_ms, + ) + self.text_feature_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_audio_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_load_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_LOAD_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_spec_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_SPEC_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) def _mark_enter(self) -> Tuple[int, int]: with self.lock: @@ -94,15 +165,29 @@ class PrepareCoordinator: with self.lock: self.inflight = max(0, self.inflight - 1) - def snapshot(self) -> Dict[str, int]: + def snapshot(self) -> Dict[str, Any]: with self.lock: - return { + snapshot: Dict[str, Any] = { "inflight": int(self.inflight), "peak_inflight": int(self.peak_inflight), "max_inflight": int(self.max_inflight), "text_feature_workers": int(self.text_feature_workers), "ref_audio_workers": int(self.ref_audio_workers), } + runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None) + if callable(runtime_snapshot_fn): + try: + snapshot["prepare_runtime_state"] = runtime_snapshot_fn() + except Exception: + snapshot["prepare_runtime_state"] = None + snapshot["prepare_stage_gates"] = { + "text_cpu": self.text_cpu_gate.snapshot(), + "text_feature": self.text_feature_gate.snapshot(), + "ref_audio": self.ref_audio_gate.snapshot(), + "ref_load": self.ref_load_gate.snapshot(), + "ref_spec": self.ref_spec_gate.snapshot(), + } + return snapshot @staticmethod def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: @@ -119,6 +204,12 @@ class PrepareCoordinator: def _prepare_text_cpu(self, text: str, language: str): return self.tts.prepare_text_segments(text, language) + def _load_ref_audio_raw(self, ref_audio_path: str): + return self.tts._load_ref_audio_raw(ref_audio_path) + + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): + return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + @staticmethod def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: feature_dim = 1024 @@ -155,17 +246,54 @@ class PrepareCoordinator: return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + await self.text_cpu_gate.acquire() if text in [None, ""]: - submit_at = time.perf_counter() - return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + try: + submit_at = time.perf_counter() + return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + finally: + self.text_cpu_gate.release() + text_cpu_worker = getattr(self.tts, "prepare_text_cpu_worker", None) executor = getattr(self.tts, "prepare_text_cpu_executor", None) - if executor is None: - submit_at = time.perf_counter() - return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) - return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + try: + if text_cpu_worker is not None: + submit_at = time.perf_counter() + result, worker_profile = await text_cpu_worker.submit_async(text, language) + started_at = float( + submit_at + + ( + float(worker_profile.get("text_cpu_admission_wait_ms", 0.0)) + + float(worker_profile.get("text_cpu_queue_wait_ms", 0.0)) + ) + / 1000.0 + ) + finished_at = float(started_at + float(worker_profile.get("text_cpu_run_ms", 0.0)) / 1000.0) + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=finished_at, + profile=dict(worker_profile), + ) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + finally: + self.text_cpu_gate.release() async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: - return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms) + await self.text_feature_gate.acquire() + try: + return await self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + prepared_segments, + language, + cpu_run_ms, + ) + finally: + self.text_feature_gate.release() @staticmethod def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: @@ -199,16 +327,34 @@ class PrepareCoordinator: ) return prompt_profiled, target_profiled + await self.text_feature_gate.acquire() target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} submit_at = time.perf_counter() started_at = float(submit_at) - if prompt_is_empty: - target_result_raw = await self.tts.build_text_features_from_segments_async( - target_segments, - profile=target_profile, - ) - prompt_result = self._build_empty_text_features_like( - PreparedTextFeatures( + try: + if prompt_is_empty: + target_result_raw = await self.tts.build_text_features_from_segments_async( + target_segments, + profile=target_profile, + ) + prompt_result = self._build_empty_text_features_like( + PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + ) + finished_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + ) + target_result = PreparedTextFeatures( phones=target_result_raw[0], bert_features=target_result_raw[1], norm_text=target_result_raw[2], @@ -216,13 +362,37 @@ class PrepareCoordinator: total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), cpu_preprocess_ms=float(target_cpu_run_ms), ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > target_profiled.finished_at: + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, ) finished_at = time.perf_counter() - prompt_profiled = ProfiledResult( - result=prompt_result, - submit_at=float(submit_at), - started_at=float(submit_at), - finished_at=float(submit_at), + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), ) target_result = PreparedTextFeatures( phones=target_result_raw[0], @@ -232,79 +402,152 @@ class PrepareCoordinator: total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), cpu_preprocess_ms=float(target_cpu_run_ms), ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) target_profiled = ProfiledResult( result=target_result, submit_at=float(submit_at), started_at=started_at, finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), ) - if finished_at > target_profiled.finished_at: + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) target_result.profile["bert_total_ms"] = max( self._estimate_text_feature_run_ms(target_profile), (finished_at - submit_at) * 1000.0, ) else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) return prompt_profiled, target_profiled - - prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} - prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( - prompt_segments, - target_segments, - prompt_profile=prompt_profile, - target_profile=target_profile, - ) - finished_at = time.perf_counter() - - prompt_result = PreparedTextFeatures( - phones=prompt_result_raw[0], - bert_features=prompt_result_raw[1], - norm_text=prompt_result_raw[2], - profile=prompt_profile, - total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), - cpu_preprocess_ms=float(prompt_cpu_run_ms), - ) - target_result = PreparedTextFeatures( - phones=target_result_raw[0], - bert_features=target_result_raw[1], - norm_text=target_result_raw[2], - profile=target_profile, - total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), - cpu_preprocess_ms=float(target_cpu_run_ms), - ) - prompt_profiled = ProfiledResult( - result=prompt_result, - submit_at=float(submit_at), - started_at=started_at, - finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), - ) - target_profiled = ProfiledResult( - result=target_result, - submit_at=float(submit_at), - started_at=started_at, - finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), - ) - if finished_at > prompt_profiled.finished_at: - prompt_result.profile["bert_total_ms"] = max( - self._estimate_text_feature_run_ms(prompt_profile), - (finished_at - submit_at) * 1000.0, - ) - target_result.profile["bert_total_ms"] = max( - self._estimate_text_feature_run_ms(target_profile), - (finished_at - submit_at) * 1000.0, - ) - else: - prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) - target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) - return prompt_profiled, target_profiled + finally: + self.text_feature_gate.release() async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: - return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + submit_at = time.perf_counter() + started_at = float(submit_at) + + await self.ref_load_gate.acquire() + try: + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + finally: + self.ref_load_gate.release() + + raw_audio, raw_sr = load_profiled.result + prompt_semantic_task = asyncio.create_task( + self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + await self.ref_spec_gate.acquire() + try: + ref_spec_task = asyncio.create_task( + self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + ) + (prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + finally: + self.ref_spec_gate.release() + + refer_spec = ref_spec_profiled.result + limiter_snapshot = ( + self.tts.prepare_ref_audio_stage_limiter.snapshot() + if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None + else {} + ) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = ( + float(load_profiled.queue_ms) + + float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(ref_spec_profiled.queue_ms) + ) + finished_at = time.perf_counter() + result = { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_queue_ms": float(load_profiled.queue_ms), + "audio_load_ms": float(load_profiled.run_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(limiter_snapshot.get("slots", 0.0)), + ) + ), + "audio_stage_inflight_peak": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(limiter_snapshot.get("peak_inflight", 0.0)), + ) + ), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_profiled.queue_ms), + "ref_spec_ms": float(ref_spec_profiled.run_ms), + "bundle_total_ms": float( + load_profiled.queue_ms + + load_profiled.run_ms + + prompt_semantic_ms + + ref_spec_profiled.queue_ms + + ref_spec_profiled.run_ms + ), + }, + } + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(finished_at), + ) + + await self.ref_audio_gate.acquire() + try: + if hasattr(self.tts, "extract_ref_audio_bundle_async"): + submit_at = time.perf_counter() + started_at = time.perf_counter() + result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + finally: + self.ref_audio_gate.release() def _release_split_stage_slot(self) -> None: self._mark_leave() - if self._inflight_semaphore is not None: - self._inflight_semaphore.release() + self._inflight_gate.release() async def prepare_cpu_stage_profiled_async( self, @@ -312,9 +555,11 @@ class PrepareCoordinator: prepare_submit_at: float, ) -> PreparedCpuStage: admission_start = time.perf_counter() - if self._inflight_semaphore is not None: - await self._inflight_semaphore.acquire() - prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0) + admission_stats = await self._inflight_gate.acquire() + prepare_admission_wait_ms = max( + float(admission_stats.get("wait_ms", 0.0)), + (time.perf_counter() - admission_start) * 1000.0, + ) current_inflight, peak_inflight = self._mark_enter() prepare_start = time.perf_counter() prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) @@ -382,10 +627,28 @@ class PrepareCoordinator: "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), "text_feature_queue_ms": target_feature_profiled.queue_ms, "text_feature_run_ms": target_feature_profiled.run_ms, "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index 7a1f9a53..64ca133f 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -1,3 +1,4 @@ +import asyncio import threading import time import uuid @@ -51,6 +52,8 @@ class RefSemanticTask: task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None result_prompt_semantic: torch.Tensor | None = None error: Exception | None = None profile: Dict[str, float] = field(default_factory=dict) @@ -115,6 +118,41 @@ class PrepareRefSemanticBatchWorker: assert task.result_prompt_semantic is not None return task.result_prompt_semantic, dict(task.profile) + async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = RefSemanticTask( + raw_audio=raw_audio, + raw_sr=int(raw_sr), + done_loop=loop, + done_future=loop.create_future(), + ) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + return await task.done_future + + @staticmethod + def _resolve_done_future(task: RefSemanticTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_prompt_semantic is not None + task.done_future.set_result((task.result_prompt_semantic, dict(task.profile))) + + def _notify_task_done(self, task: RefSemanticTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + def snapshot(self) -> Dict[str, int]: with self.condition: return { @@ -247,7 +285,7 @@ class PrepareRefSemanticBatchWorker: for task in batch: if task.result_prompt_semantic is not None: task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) - task.done_event.set() + self._notify_task_done(task) def _run_loop(self) -> None: while True: @@ -257,6 +295,6 @@ class PrepareRefSemanticBatchWorker: except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc - task.done_event.set() + self._notify_task_done(task) finally: self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py new file mode 100644 index 00000000..b7c985f0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py @@ -0,0 +1,215 @@ +import asyncio +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, Tuple + + +@dataclass +class TextCpuTask: + text: str + language: str + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + backpressure_wait_ms: float = 0.0 + capacity_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 + done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None + result: Any = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareTextCpuWorker: + def __init__( + self, + process_fn: Callable[[str, str], Any], + worker_count: int, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None, + ) -> None: + self.process_fn = process_fn + self.worker_count = max(1, int(worker_count)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + self.admission_controller = admission_controller + + self.condition = threading.Condition() + self.pending_tasks: Deque[TextCpuTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.active_workers = 0 + self.active_workers_peak = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 + self.backpressure_wait_total_ms = 0.0 + self.backpressure_wait_peak_ms = 0.0 + self.capacity_wait_total_ms = 0.0 + self.capacity_wait_peak_ms = 0.0 + self.backpressure_blocked_total = 0 + + self.worker_threads = [ + threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True) + for index in range(self.worker_count) + ] + for thread in self.worker_threads: + thread.start() + + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks + + def _get_admission_state(self) -> Dict[str, float | int | bool]: + if self.admission_controller is None: + return {"blocked": False} + try: + state = dict(self.admission_controller() or {}) + except Exception: + return {"blocked": False} + state["blocked"] = bool(state.get("blocked", False)) + return state + + def _record_enqueue_locked( + self, + task: TextCpuTask, + *, + admission_wait_ms: float, + backpressure_wait_ms: float, + capacity_wait_ms: float, + ) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms)) + task.capacity_wait_ms = float(max(0.0, capacity_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + self.backpressure_wait_total_ms += task.backpressure_wait_ms + self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms) + self.capacity_wait_total_ms += task.capacity_wait_ms + self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms) + if task.backpressure_wait_ms > 0.0: + self.backpressure_blocked_total += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + async def _enqueue_task_async(self, task: TextCpuTask) -> None: + admission_started = time.perf_counter() + backpressure_wait_ms = 0.0 + capacity_wait_ms = 0.0 + while True: + loop_start = time.perf_counter() + admission_state = self._get_admission_state() + blocked = bool(admission_state.get("blocked", False)) + with self.condition: + if not blocked and self._can_enqueue_locked(): + self._record_enqueue_locked( + task, + admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0, + backpressure_wait_ms=backpressure_wait_ms, + capacity_wait_ms=capacity_wait_ms, + ) + return + await asyncio.sleep(self.admission_poll_s) + waited_ms = (time.perf_counter() - loop_start) * 1000.0 + if blocked: + backpressure_wait_ms += waited_ms + else: + capacity_wait_ms += waited_ms + + def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + task = TextCpuTask(text=str(text), language=str(language)) + asyncio.run(self._enqueue_task_async(task)) + task.done_event.wait() + if task.error is not None: + raise task.error + return task.result, dict(task.profile) + + async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = TextCpuTask( + text=str(text), + language=str(language), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + + @staticmethod + def _resolve_done_future(task: TextCpuTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + task.done_future.set_result((task.result, dict(task.profile))) + + def _notify_task_done(self, task: TextCpuTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + + def snapshot(self) -> Dict[str, int | float]: + with self.condition: + return { + "worker_count": int(self.worker_count), + "pending": int(len(self.pending_tasks)), + "pending_peak": int(self.pending_peak), + "active_workers": int(self.active_workers), + "active_workers_peak": int(self.active_workers_peak), + "total_submitted": int(self.total_submitted), + "total_finished": int(self.total_finished), + "max_pending_tasks": int(self.max_pending_tasks), + "admission_wait_total_ms": float(self.admission_wait_total_ms), + "admission_wait_peak_ms": float(self.admission_wait_peak_ms), + "backpressure_wait_total_ms": float(self.backpressure_wait_total_ms), + "backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms), + "capacity_wait_total_ms": float(self.capacity_wait_total_ms), + "capacity_wait_peak_ms": float(self.capacity_wait_peak_ms), + "backpressure_blocked_total": int(self.backpressure_blocked_total), + } + + def _run_loop(self) -> None: + while True: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + task = self.pending_tasks.popleft() + self.active_workers += 1 + self.active_workers_peak = max(self.active_workers_peak, self.active_workers) + started_at = time.perf_counter() + try: + task.result = self.process_fn(task.text, task.language) + task.profile = { + "text_cpu_admission_wait_ms": float(task.admission_wait_ms), + "text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms), + "text_cpu_capacity_wait_ms": float(task.capacity_wait_ms), + "text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0), + "text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + finally: + with self.condition: + self.active_workers = max(0, self.active_workers - 1) + self.total_finished += 1 + self.condition.notify_all() + self._notify_task_done(task) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 78dbfc36..ed465f69 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -421,6 +421,55 @@ def _iter_contiguous_sampling_groups( return groups +def _uniform_sampling_group_key(active_batch: T2SActiveBatch) -> Optional[Tuple[int, float, float, float, bool]]: + if not active_batch.states: + return None + if active_batch.step_indices.numel() <= 0: + return None + first_step_index = int(active_batch.step_indices[0].item()) + if bool((active_batch.step_indices != first_step_index).any().item()): + return None + first_state = active_batch.states[0] + first_key = _sampling_group_key( + top_k=first_state.top_k, + top_p=first_state.top_p, + temperature=first_state.temperature, + repetition_penalty=first_state.repetition_penalty, + trim_eos=first_step_index < 11, + ) + for state in active_batch.states[1:]: + if ( + state.top_k != first_state.top_k + or state.top_p != first_state.top_p + or state.temperature != first_state.temperature + or state.repetition_penalty != first_state.repetition_penalty + ): + return None + return first_key + + +def _batched_sample_uniform( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_key: Tuple[int, float, float, float, bool], +) -> Tuple[torch.Tensor, torch.Tensor]: + top_k, top_p, temperature, repetition_penalty, trim_eos = sampling_key + sample_logits = logits[:, :-1] if trim_eos else logits + padded_histories, history_mask = _pad_token_sequences(histories) + probs = logits_to_probs( + logits=sample_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + sampled = multinomial_sample_one_no_sync(probs) + argmax_tokens = torch.argmax(sample_logits, dim=-1) + return sampled, argmax_tokens + + def _batched_sample_by_group( logits: torch.Tensor, histories: Sequence[torch.LongTensor], @@ -594,27 +643,59 @@ def _sample_per_request( keep_indices: List[int] = [] updated_sequences: List[torch.LongTensor] = [] - sampling_keys = [ - _sampling_group_key( - top_k=state.top_k, - top_p=state.top_p, - temperature=state.temperature, - repetition_penalty=state.repetition_penalty, - trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + uniform_sampling_key = _uniform_sampling_group_key(active_batch) + sampled_items: List[torch.Tensor] + argmax_tokens: List[int] + sampled_token_tensor: Optional[torch.Tensor] = None + argmax_token_tensor: Optional[torch.Tensor] = None + if uniform_sampling_key is not None: + sampled_tensor, argmax_tensor = _batched_sample_uniform( + logits=logits, + histories=active_batch.y_sequences, + sampling_key=uniform_sampling_key, + ) + sampled_token_tensor = sampled_tensor.view(-1) + argmax_token_tensor = argmax_tensor.view(-1) + if ( + all(state.early_stop_num == -1 for state in active_batch.states) + and int(active_batch.step_indices[0].item()) + 1 < max_steps + and not bool(sampled_token_tensor.eq(model.EOS).any().item()) + and not bool(argmax_token_tensor.eq(model.EOS).any().item()) + ): + return ( + [], + list(range(len(active_batch.states))), + [torch.cat([history, sampled_token_tensor[index : index + 1]], dim=0) for index, history in enumerate(active_batch.y_sequences)], + ) + sampled_items = [sampled_tensor[index : index + 1] for index in range(sampled_tensor.shape[0])] + argmax_tokens = [int(item) for item in argmax_tensor.tolist()] + else: + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + ) + for batch_index, state in enumerate(active_batch.states) + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, ) - for batch_index, state in enumerate(active_batch.states) - ] - sampled_items, argmax_tokens = _batched_sample_by_group( - logits=logits, - histories=active_batch.y_sequences, - sampling_keys=sampling_keys, - ) for batch_index, state in enumerate(active_batch.states): step_index = int(active_batch.step_indices[batch_index].item()) current_history = active_batch.y_sequences[batch_index] - sampled = sampled_items[batch_index] - sampled_token = int(sampled[0, 0].item()) - argmax_token = argmax_tokens[batch_index] + if sampled_token_tensor is not None and argmax_token_tensor is not None: + sampled = sampled_token_tensor[batch_index : batch_index + 1] + sampled_token = int(sampled_token_tensor[batch_index].item()) + argmax_token = int(argmax_token_tensor[batch_index].item()) + else: + sampled = sampled_items[batch_index] + sampled_token = int(sampled[0, 0].item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None @@ -690,6 +771,13 @@ def decode_one_step( finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) if len(keep_indices) == 0: return None, finished_items + if len(keep_indices) == len(active_batch.request_ids): + active_batch.y_sequences = updated_sequences + active_batch.step_indices = active_batch.step_indices + 1 + if not was_prefill and active_batch.kv_lens is not None: + active_batch.kv_lens = active_batch.kv_lens + 1 + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + return active_batch, finished_items device = logits.device keep_tensor = torch.LongTensor(keep_indices).to(device) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py index f950c68d..e31c5dfe 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py @@ -314,6 +314,7 @@ def build_scheduler_submit_headers( "X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), "X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))), "X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), "X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), "X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index 7f4e485f..15eedeca 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -44,6 +44,16 @@ class EngineTaskQueueOwner: return None return self.queue.popleft() + def pop_left_many(self, max_items: int) -> List[Any]: + limit = max(1, int(max_items)) + with self.condition: + if not self.queue: + return [] + selected: List[Any] = [] + while self.queue and len(selected) < limit: + selected.append(self.queue.popleft()) + return selected + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: if count <= 0: return @@ -315,6 +325,7 @@ class EngineGpuPrepareTask: engine_request_id: str | None enqueue_time: float queue_wait_ms: float = 0.0 + admission_wait_ms: float = 0.0 error: str | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py index bb3e8b06..b9095d2c 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import time from typing import Any @@ -9,6 +10,19 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare class EnginePrepareStageMixin: + async def _wait_prepare_queue_admission(self) -> float: + soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) + if soft_max <= 0: + return 0.0 + poll_s = max( + 0.0005, + float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, + ) + wait_start = time.perf_counter() + while self.prepare_queue_owner.waiting_count() >= soft_max: + await asyncio.sleep(poll_s) + return max(0.0, (time.perf_counter() - wait_start) * 1000.0) + async def prepare_state_via_engine_gpu_queue( self, *, @@ -16,12 +30,14 @@ class EnginePrepareStageMixin: prepare_submit_at: float, engine_request_id: str | None, ) -> tuple[T2SRequestState, float, float]: + prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission() cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) if engine_request_id not in [None, ""]: self.update_request_state( str(engine_request_id), EngineStatus.GPU_PREPARING, { + "engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms), "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), @@ -37,31 +53,44 @@ class EnginePrepareStageMixin: done_future=done_future, engine_request_id=engine_request_id or spec.request_id, enqueue_time=time.perf_counter(), + admission_wait_ms=float(prepare_queue_admission_wait_ms), ) self.prepare_queue_owner.enqueue(task) self.notify_arbiter() return await done_future def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + if not tasks: return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + batch_results = asyncio.run( + self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks]) + ) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + if isinstance(result, Exception): + task.error = str(result) + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + state, prepare_exec_started_at, prepare_exec_finished_at = result + state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms) state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) if task.engine_request_id not in [None, ""]: self.merge_request_state_profile( str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms), + "engine_gpu_prepare_batch_size": float(len(tasks)), + }, ) - self.prepare_queue_owner.mark_completed(1) self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True + completed_count += 1 + self.prepare_queue_owner.mark_completed(completed_count) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py index 28da24ee..8b3db1fa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import time from typing import Callable, Dict, List @@ -32,6 +33,11 @@ class WorkerPrepareExecutor: def get_max_inflight(self) -> int: return int(self.coordinator.snapshot().get("max_inflight", 0)) + def get_batch_policy(self) -> Dict[str, int]: + return { + "prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))), + } + def is_idle(self) -> bool: return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 @@ -69,3 +75,17 @@ class WorkerPrepareExecutor: return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) finally: self._notify_state_change() + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + try: + return list( + await asyncio.gather( + *[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py index f1910409..e498e9ea 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -32,6 +32,9 @@ class WorkerSubmitLifecycleMixin: def get_finalize_batch_policy(self) -> Dict[str, Any]: return dict(self.finalize_executor.get_batch_policy()) + def get_prepare_batch_policy(self) -> Dict[str, int]: + return dict(self.prepare_executor.get_batch_policy()) + def get_decode_runtime_counters(self) -> Dict[str, int]: with self.condition: return self.decode_runtime_tracker.get_counters() @@ -258,3 +261,9 @@ class WorkerSubmitLifecycleMixin: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) diff --git a/third_party/g2pw-cu b/third_party/g2pw-cu new file mode 160000 index 00000000..a53cf4ee --- /dev/null +++ b/third_party/g2pw-cu @@ -0,0 +1 @@ +Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98