diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 92c829a1..81c1ca1e 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -529,6 +529,7 @@ class TTS:
             self.bert_model,
             self.bert_tokenizer,
             self.configs.device,
+            version=self.configs.version,
             bert_stage_limiter=self.prepare_bert_stage_limiter,
             bert_batch_worker=self.prepare_bert_batch_worker,
         )
@@ -558,6 +559,16 @@
         return None

     def snapshot_prepare_runtime_components(self) -> dict:
+        g2pw_runtime = None
+        try:
+            from text import chinese2
+
+            g2pw_instance = getattr(chinese2, "g2pw", None)
+            g2pw_backend = None if g2pw_instance is None else getattr(g2pw_instance, "_g2pw", None)
+            if g2pw_backend is not None and hasattr(g2pw_backend, "snapshot"):
+                g2pw_runtime = dict(g2pw_backend.snapshot())
+        except Exception:
+            g2pw_runtime = None
         return {
             "text_cpu": {
                 "workers": int(self.prepare_text_cpu_workers),
@@ -587,6 +598,7 @@
             "text_preprocessor": (
                 None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot()
             ),
+            "g2pw": g2pw_runtime,
         }

     def _build_text_cpu_admission_state(self) -> dict:
@@ -1204,6 +1216,9 @@
     def prepare_text_segments(self, text: str, language: str):
         return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version)

+    def resolve_g2pw_segments(self, prepared_segments, profile: dict | None = None):
+        return self.text_preprocessor.resolve_g2pw_segments(prepared_segments, profile=profile)
+
     def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None):
         return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile)

diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 01a8ea4d..c30e3195 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -101,6 +101,7 @@ class PreparedTextSegment:
     phones: List[int]
     word2ph: Optional[List[int]]
     norm_text: str
+    needs_g2pw: bool = False


 class TextPreprocessor:
@@ -109,12 +110,14 @@
         bert_model: AutoModelForMaskedLM,
         tokenizer: AutoTokenizer,
         device: torch.device,
+        version: str = "v2",
         bert_stage_limiter: StageLimiter | None = None,
         bert_batch_worker: PrepareBertBatchWorker | None = None,
     ):
         self.bert_model = bert_model
         self.tokenizer = tokenizer
         self.device = device
+        self.version = str(version)
         self.bert_stage_limiter = bert_stage_limiter
         self.bert_batch_worker = bert_batch_worker

@@ -261,15 +264,66 @@
                 phones=list(payload["phones"]),
                 word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
                 norm_text=str(payload["norm_text"]),
+                needs_g2pw=bool(payload.get("needs_g2pw", False)),
             )
             for payload in payloads
         ]

+    def resolve_g2pw_segments(
+        self,
+        prepared_segments: List[PreparedTextSegment],
+        profile: Dict | None = None,
+    ) -> List[PreparedTextSegment]:
+        zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)]
+        if not zh_indices:
+            return prepared_segments
+        from text import chinese2
+
+        normalized_segments = [prepared_segments[index].norm_text for index in zh_indices]
+        resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True)
+        self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0))
+        self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0))
+        self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0))
+        self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0))
+        self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0))
+        self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0))
+        self._accumulate_profile(
+            profile,
+            "g2pw_runtime_collect_wait_ms",
+            g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0),
+        )
+        self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0))
+        self._update_profile_peak(
+            profile,
+            "g2pw_runtime_batch_rows_peak",
+            g2pw_profile.get("g2pw_runtime_batch_rows", 0.0),
+        )
+        self._update_profile_peak(
+            profile,
+            "g2pw_runtime_batch_requests_peak",
+            g2pw_profile.get("g2pw_runtime_batch_requests", 0.0),
+        )
+        self._update_profile_peak(
+            profile,
+            "g2pw_runtime_pool_workers",
+            g2pw_profile.get("g2pw_runtime_pool_workers", 0.0),
+        )
+        for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments):
+            prepared_segments[index] = PreparedTextSegment(
+                language=prepared_segments[index].language,
+                phones=list(cleaned_text_to_sequence(phones, self.version)),
+                word2ph=None if word2ph is None else list(word2ph),
+                norm_text=str(norm_text),
+                needs_g2pw=False,
+            )
+        return prepared_segments
+
     def build_phones_and_bert_from_segments(
         self,
         prepared_segments: List[PreparedTextSegment],
         profile: Dict | None = None,
     ) -> Tuple[list, torch.Tensor, str]:
+        prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
         phones_list: List[List[int]] = []
         bert_list: List[torch.Tensor] = []
         norm_text_list: List[str] = []
@@ -402,6 +456,7 @@
         prepared_segments: List[PreparedTextSegment],
         profile: Dict | None = None,
     ) -> Tuple[list, torch.Tensor, str]:
+        prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
         segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
         pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
         for segment_index, segment in enumerate(prepared_segments):
@@ -473,6 +528,8 @@
         prompt_profile: Dict | None = None,
         target_profile: Dict | None = None,
     ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
+        prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile)
+        target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile)
         prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
         target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
         pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
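Reviewer note on the contract above: Chinese segments now leave CPU preprocessing with empty `phones` and `needs_g2pw=True`, and `resolve_g2pw_segments` later rewrites them in place after one batched g2pw call. A minimal self-contained sketch of that flow, with a stub standing in for `chinese2.g2p_segments` (the stub's values are illustrative only, not real phoneme output):

```python
# Sketch of the needs_g2pw deferral contract; the stub resolver below is a
# placeholder for chinese2.g2p_segments, not the real phoneme model.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Segment:  # mirrors PreparedTextSegment
    language: str
    phones: List[int]
    word2ph: Optional[List[int]]
    norm_text: str
    needs_g2pw: bool = False


def resolve(segments: List[Segment]) -> List[Segment]:
    # One batched pass over every deferred zh segment, as in resolve_g2pw_segments.
    pending = [i for i, s in enumerate(segments) if s.needs_g2pw]
    if not pending:
        return segments
    texts = [segments[i].norm_text for i in pending]
    resolved = [([1, 2], [2], t) for t in texts]  # stand-in g2pw output
    for i, (phones, word2ph, norm_text) in zip(pending, resolved):
        segments[i] = Segment(segments[i].language, phones, word2ph, norm_text, False)
    return segments


segs = [Segment("zh", [], None, "你好", needs_g2pw=True), Segment("en", [5], None, "hi")]
assert all(not s.needs_g2pw for s in resolve(segs))
```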
diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
index 71134268..65bcbb51 100644
--- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
+++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
@@ -120,6 +120,15 @@ class PrepareCoordinator:
             max_workers=self.text_feature_workers,
             thread_name_prefix="prepare-text-feature",
         )
+        g2pw_default_workers = max(8, int(getattr(tts, "prepare_text_cpu_workers", 8) or 8))
+        self.g2pw_workers = max(
+            1,
+            int(os.environ.get("GPTSOVITS_PREPARE_G2PW_WORKERS", str(g2pw_default_workers))),
+        )
+        self.g2pw_executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.g2pw_workers,
+            thread_name_prefix="prepare-g2pw",
+        )
         ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
         self.ref_audio_workers = max(
             1,
@@ -130,12 +139,17 @@
             thread_name_prefix="prepare-ref-audio",
         )
         text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0))
+        g2pw_gate_default = max(0, int(self.g2pw_workers))
         text_feature_gate_default = max(0, int(self.text_feature_workers))
         ref_audio_gate_default = max(0, int(self.ref_audio_workers))
         self.text_cpu_gate = AsyncStageGate(
             int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))),
             poll_ms=gate_poll_ms,
         )
+        self.g2pw_gate = AsyncStageGate(
+            int(os.environ.get("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", str(g2pw_gate_default))),
+            poll_ms=gate_poll_ms,
+        )
         self.text_feature_gate = AsyncStageGate(
             int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))),
             poll_ms=gate_poll_ms,
@@ -172,6 +186,7 @@
             "peak_inflight": int(self.peak_inflight),
             "max_inflight": int(self.max_inflight),
             "text_feature_workers": int(self.text_feature_workers),
+            "g2pw_workers": int(self.g2pw_workers),
             "ref_audio_workers": int(self.ref_audio_workers),
         }
         runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None)
@@ -182,6 +197,7 @@
             snapshot["prepare_runtime_state"] = None
         snapshot["prepare_stage_gates"] = {
             "text_cpu": self.text_cpu_gate.snapshot(),
+            "g2pw": self.g2pw_gate.snapshot(),
             "text_feature": self.text_feature_gate.snapshot(),
             "ref_audio": self.ref_audio_gate.snapshot(),
             "ref_load": self.ref_load_gate.snapshot(),
@@ -204,6 +220,11 @@
     def _prepare_text_cpu(self, text: str, language: str):
         return self.tts.prepare_text_segments(text, language)

+    def _resolve_g2pw_segments(self, prepared_segments):
+        profile: Dict[str, float] = {}
+        resolved_segments = self.tts.resolve_g2pw_segments(prepared_segments, profile=profile)
+        return resolved_segments, profile
+
     def _load_ref_audio_raw(self, ref_audio_path: str):
         return self.tts._load_ref_audio_raw(ref_audio_path)

@@ -225,8 +246,15 @@
             dtype=(dtype if dtype is not None else None) or __import__("torch").float32,
         )

-    def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
-        profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
+    def _build_text_features(
+        self,
+        prepared_segments,
+        language: str,
+        cpu_run_ms: float,
+        base_profile: Dict[str, float] | None = None,
+    ) -> PreparedTextFeatures:
+        profile: Dict[str, float] = dict(base_profile or {})
+        profile["cpu_preprocess_ms"] = float(cpu_run_ms)
         branch_start = time.perf_counter()
         phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile)
         total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0)
@@ -291,10 +319,53 @@
                 prepared_segments,
                 language,
                 cpu_run_ms,
+                None,
             )
         finally:
             self.text_feature_gate.release()

+    async def _run_g2pw_stage(self, prepared_segments) -> ProfiledResult:
+        has_pending = any(bool(getattr(segment, "needs_g2pw", False)) for segment in (prepared_segments or []))
+        if not has_pending:
+            submit_at = time.perf_counter()
+            return ProfiledResult(
+                result=prepared_segments,
+                submit_at=float(submit_at),
+                started_at=float(submit_at),
+                finished_at=float(submit_at),
+                profile={},
+            )
+        await self.g2pw_gate.acquire()
+        try:
+            profiled = await self._run_on_executor(self.g2pw_executor, self._resolve_g2pw_segments, prepared_segments)
+            result, stage_profile = profiled.result
+            return ProfiledResult(
+                result=result,
+                submit_at=float(profiled.submit_at),
+                started_at=float(profiled.started_at),
+                finished_at=float(profiled.finished_at),
+                profile=dict(stage_profile),
+            )
+        finally:
+            self.g2pw_gate.release()
+
+    async def _run_g2pw_pair_stage(self, prompt_segments, target_segments) -> tuple[ProfiledResult, ProfiledResult]:
+        prompt_is_empty = len(prompt_segments or []) == 0
+        target_task = asyncio.create_task(self._run_g2pw_stage(target_segments))
+        if not prompt_is_empty:
+            prompt_task = asyncio.create_task(self._run_g2pw_stage(prompt_segments))
+            return await asyncio.gather(prompt_task, target_task)
+        target_profiled = await target_task
+        submit_at = time.perf_counter()
+        prompt_profiled = ProfiledResult(
+            result=prompt_segments,
+            submit_at=float(submit_at),
+            started_at=float(submit_at),
+            finished_at=float(submit_at),
+            profile={},
+        )
+        return prompt_profiled, target_profiled
+
     @staticmethod
     def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
         return float(
@@ -310,12 +381,32 @@
         target_segments,
         prompt_cpu_run_ms: float,
         target_cpu_run_ms: float,
+        prompt_base_profile: Dict[str, float] | None = None,
+        target_base_profile: Dict[str, float] | None = None,
     ) -> tuple[ProfiledResult, ProfiledResult]:
         prompt_is_empty = len(prompt_segments or []) == 0
         if self.text_feature_executor is not None:
-            target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms))
+            target_feature_task = asyncio.create_task(
+                self._run_on_executor(
+                    self.text_feature_executor,
+                    self._build_text_features,
+                    target_segments,
+                    None,
+                    target_cpu_run_ms,
+                    target_base_profile,
+                )
+            )
             if not prompt_is_empty:
-                prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms))
+                prompt_feature_task = asyncio.create_task(
+                    self._run_on_executor(
+                        self.text_feature_executor,
+                        self._build_text_features,
+                        prompt_segments,
+                        None,
+                        prompt_cpu_run_ms,
+                        prompt_base_profile,
+                    )
+                )
                 return await asyncio.gather(prompt_feature_task, target_feature_task)
             target_profiled = await target_feature_task
             submit_at = time.perf_counter()
@@ -328,7 +419,8 @@
             return prompt_profiled, target_profiled

         await self.text_feature_gate.acquire()
-        target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
+        target_profile: Dict[str, float] = dict(target_base_profile or {})
+        target_profile["cpu_preprocess_ms"] = float(target_cpu_run_ms)
         submit_at = time.perf_counter()
         started_at = float(submit_at)
         try:
@@ -377,7 +469,8 @@
             target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
             return prompt_profiled, target_profiled

-        prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
+        prompt_profile: Dict[str, float] = dict(prompt_base_profile or {})
+        prompt_profile["cpu_preprocess_ms"] = float(prompt_cpu_run_ms)
         prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
             prompt_segments,
             target_segments,
@@ -589,20 +682,31 @@
         cpu_stage: PreparedCpuStage,
     ) -> tuple[T2SRequestState, float, float]:
         try:
-            text_pair_start = time.perf_counter()
-            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
-            text_feature_pair_task = asyncio.create_task(
-                self._run_text_feature_pair_stage(
+            g2pw_pair_start = time.perf_counter()
+            g2pw_pair_task = asyncio.create_task(
+                self._run_g2pw_pair_stage(
                     cpu_stage.prompt_cpu_profiled.result,
                     cpu_stage.target_cpu_profiled.result,
-                    cpu_stage.prompt_cpu_profiled.run_ms,
-                    cpu_stage.target_cpu_profiled.run_ms,
                 )
             )
-            (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather(
-                text_feature_pair_task,
+            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
+            (prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather(
+                g2pw_pair_task,
                 ref_audio_task,
             )
+            g2pw_pair_end = time.perf_counter()
+            text_pair_start = time.perf_counter()
+            text_feature_pair_task = asyncio.create_task(
+                self._run_text_feature_pair_stage(
+                    prompt_g2pw_profiled.result,
+                    target_g2pw_profiled.result,
+                    cpu_stage.prompt_cpu_profiled.run_ms,
+                    cpu_stage.target_cpu_profiled.run_ms,
+                    prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
+                    target_base_profile=dict(target_g2pw_profiled.profile or {}),
+                )
+            )
+            prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task
             text_pair_end = time.perf_counter()
             state = build_request_state_from_parts(
                 tts=self.tts,
@@ -619,6 +723,17 @@
                 "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
                 "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
                 "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
+                "g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0),
+                "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
+                "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
+                "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
+                "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
+                "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
+                "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
+                "text_g2pw_run_ms": target_g2pw_profiled.run_ms,
+                "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
+                "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
+                "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
                 "prompt_text_parallel_future_wait_ms": 0.0,
                 "prompt_text_parallel_future_executor_queue_ms": 0.0,
                 "prompt_text_parallel_future_run_ms": 0.0,
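For tuning, the new stage's sizing reduces to the following (a condensed restatement of the defaults wired above; env names are the ones introduced in this diff):

```python
import os


def g2pw_stage_size(prepare_text_cpu_workers: int) -> tuple[int, int]:
    # Workers default to max(8, text-CPU workers); the admission gate defaults
    # to the worker count. Both can be overridden via environment variables.
    default_workers = max(8, int(prepare_text_cpu_workers or 8))
    workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_G2PW_WORKERS", str(default_workers))))
    max_inflight = int(os.environ.get("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", str(max(0, workers))))
    return workers, max_inflight


print(g2pw_stage_size(4))  # -> (8, 8) with no overrides set
```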
diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
index ed465f69..e993a1ef 100644
--- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
+++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
@@ -241,6 +241,24 @@ def build_request_state_from_parts(
             prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0)
         ),
         "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)),
+        "prompt_text_g2pw_total_ms": float(prompt_result.profile.get("g2pw_total_ms", 0.0)),
+        "prompt_text_g2pw_prepare_ms": float(prompt_result.profile.get("g2pw_prepare_ms", 0.0)),
+        "prompt_text_g2pw_predict_ms": float(prompt_result.profile.get("g2pw_predict_ms", 0.0)),
+        "prompt_text_g2pw_post_ms": float(prompt_result.profile.get("g2pw_post_ms", 0.0)),
+        "prompt_text_g2pw_runtime_total_ms": float(prompt_result.profile.get("g2pw_runtime_total_ms", 0.0)),
+        "prompt_text_g2pw_runtime_queue_wait_ms": float(
+            prompt_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)
+        ),
+        "prompt_text_g2pw_runtime_collect_wait_ms": float(
+            prompt_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)
+        ),
+        "prompt_text_g2pw_runtime_run_ms": float(prompt_result.profile.get("g2pw_runtime_run_ms", 0.0)),
+        "prompt_text_g2pw_runtime_batch_rows_peak": float(
+            prompt_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)
+        ),
+        "prompt_text_g2pw_runtime_batch_requests_peak": float(
+            prompt_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0)
+        ),
         "prompt_text_parallel_future_wait_ms": 0.0,
         "prompt_text_parallel_future_executor_queue_ms": 0.0,
         "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms),
@@ -267,6 +285,18 @@
         ),
         "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)),
         "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)),
+        "text_g2pw_total_ms": float(target_result.profile.get("g2pw_total_ms", 0.0)),
+        "text_g2pw_prepare_ms": float(target_result.profile.get("g2pw_prepare_ms", 0.0)),
+        "text_g2pw_predict_ms": float(target_result.profile.get("g2pw_predict_ms", 0.0)),
+        "text_g2pw_post_ms": float(target_result.profile.get("g2pw_post_ms", 0.0)),
+        "text_g2pw_runtime_total_ms": float(target_result.profile.get("g2pw_runtime_total_ms", 0.0)),
+        "text_g2pw_runtime_queue_wait_ms": float(target_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)),
+        "text_g2pw_runtime_collect_wait_ms": float(target_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)),
+        "text_g2pw_runtime_run_ms": float(target_result.profile.get("g2pw_runtime_run_ms", 0.0)),
+        "text_g2pw_runtime_batch_rows_peak": float(target_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)),
+        "text_g2pw_runtime_batch_requests_peak": float(
+            target_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0)
+        ),
         "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)),
         "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
         "audio_load_ms": audio_load_ms,
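The scheduler only copies `g2pw_*` keys out of the stage profiles; a downstream consumer might fold them into a one-line summary like this (the `metrics` dict shape mirrors the keys added above; the derived "other" bucket is an assumption, not a field the patch emits):

```python
def summarize_g2pw(metrics: dict) -> str:
    total = metrics.get("text_g2pw_total_ms", 0.0)
    predict = metrics.get("text_g2pw_predict_ms", 0.0)
    queue = metrics.get("text_g2pw_runtime_queue_wait_ms", 0.0)
    other = max(0.0, total - predict)  # prepare + post-processing remainder
    return f"g2pw {total:.1f}ms (predict {predict:.1f}ms, queue {queue:.1f}ms, other {other:.1f}ms)"


print(summarize_g2pw({
    "text_g2pw_total_ms": 12.4,
    "text_g2pw_predict_ms": 9.8,
    "text_g2pw_runtime_queue_wait_ms": 1.1,
}))
```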
"prompt_text_g2pw_runtime_total_ms": float(prompt_result.profile.get("g2pw_runtime_total_ms", 0.0)), + "prompt_text_g2pw_runtime_queue_wait_ms": float( + prompt_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0) + ), + "prompt_text_g2pw_runtime_collect_wait_ms": float( + prompt_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0) + ), + "prompt_text_g2pw_runtime_run_ms": float(prompt_result.profile.get("g2pw_runtime_run_ms", 0.0)), + "prompt_text_g2pw_runtime_batch_rows_peak": float( + prompt_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0) + ), + "prompt_text_g2pw_runtime_batch_requests_peak": float( + prompt_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0) + ), "prompt_text_parallel_future_wait_ms": 0.0, "prompt_text_parallel_future_executor_queue_ms": 0.0, "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms), @@ -267,6 +285,18 @@ def build_request_state_from_parts( ), "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)), "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)), + "text_g2pw_total_ms": float(target_result.profile.get("g2pw_total_ms", 0.0)), + "text_g2pw_prepare_ms": float(target_result.profile.get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float(target_result.profile.get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float(target_result.profile.get("g2pw_post_ms", 0.0)), + "text_g2pw_runtime_total_ms": float(target_result.profile.get("g2pw_runtime_total_ms", 0.0)), + "text_g2pw_runtime_queue_wait_ms": float(target_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)), + "text_g2pw_runtime_collect_wait_ms": float(target_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)), + "text_g2pw_runtime_run_ms": float(target_result.profile.get("g2pw_runtime_run_ms", 0.0)), + "text_g2pw_runtime_batch_rows_peak": float(target_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)), + "text_g2pw_runtime_batch_requests_peak": float( + target_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0) + ), "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)), "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), "audio_load_ms": audio_load_ms, diff --git a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py index e2398251..3d5b2de5 100644 --- a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py +++ b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py @@ -8,6 +8,7 @@ sys.path.append(now_dir) from text.LangSegmenter import LangSegmenter from text import cleaned_text_to_sequence +from text import chinese2 from text.cleaner import clean_text @@ -83,16 +84,27 @@ def preprocess_text_segments_payload( payloads: List[PreparedTextSegmentPayload] = [] total_phones_len = 0 for segment_text, segment_lang in zip(textlist, langlist): - phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + normalized_language = segment_lang.replace("all_", "") + if normalized_language == "zh": + norm_text = chinese2.text_normalize(segment_text) + phones = [] + word2ph = None + needs_g2pw = True + estimated_phones_len = max(0, len(norm_text) * 2) + else: + phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + needs_g2pw = False + estimated_phones_len = len(phones) payloads.append( { - "language": segment_lang.replace("all_", ""), + "language": normalized_language, "phones": phones, "word2ph": word2ph, "norm_text": 
diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py
index acfebfe2..a5d32490 100644
--- a/GPT_SoVITS/text/chinese2.py
+++ b/GPT_SoVITS/text/chinese2.py
@@ -1,5 +1,6 @@
 import os
 import re
+import time

 import cn2an
 from pypinyin import lazy_pinyin, Style
@@ -77,6 +78,205 @@ def g2p(text):
     return phones, word2ph


+def _prepare_g2p_segments(segments):
+    prepared_segments = []
+    batch_inputs = []
+    for segment in segments:
+        processed_segment = re.sub("[a-zA-Z]+", "", segment)
+        seg_cut = psg.lcut(processed_segment)
+        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
+        prepared_segments.append(
+            {
+                "segment": processed_segment,
+                "seg_cut": seg_cut,
+            }
+        )
+        if processed_segment:
+            batch_inputs.append(processed_segment)
+    return prepared_segments, batch_inputs
+
+
+def _build_segment_from_g2pw(segment: str, seg_cut, pinyins):
+    phones_list = []
+    word2ph = []
+    initials = []
+    finals = []
+    pre_word_length = 0
+    for word, pos in seg_cut:
+        sub_initials = []
+        sub_finals = []
+        now_word_length = pre_word_length + len(word)
+
+        if pos == "eng":
+            pre_word_length = now_word_length
+            continue
+
+        word_pinyins = pinyins[pre_word_length:now_word_length]
+        word_pinyins = correct_pronunciation(word, word_pinyins)
+
+        for pinyin in word_pinyins:
+            if pinyin[0].isalpha():
+                sub_initials.append(to_initials(pinyin))
+                sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
+            else:
+                sub_initials.append(pinyin)
+                sub_finals.append(pinyin)
+
+        pre_word_length = now_word_length
+        sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
+        sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
+        initials.append(sub_initials)
+        finals.append(sub_finals)
+
+    initials = sum(initials, [])
+    finals = sum(finals, [])
+    for c, v in zip(initials, finals):
+        raw_pinyin = c + v
+        if c == v:
+            assert c in punctuation
+            phone = [c]
+            word2ph.append(1)
+        else:
+            v_without_tone = v[:-1]
+            tone = v[-1]
+
+            pinyin = c + v_without_tone
+            assert tone in "12345"
+
+            if c:
+                v_rep_map = {
+                    "uei": "ui",
+                    "iou": "iu",
+                    "uen": "un",
+                }
+                if v_without_tone in v_rep_map.keys():
+                    pinyin = c + v_rep_map[v_without_tone]
+            else:
+                pinyin_rep_map = {
+                    "ing": "ying",
+                    "i": "yi",
+                    "in": "yin",
+                    "u": "wu",
+                }
+                if pinyin in pinyin_rep_map.keys():
+                    pinyin = pinyin_rep_map[pinyin]
+                else:
+                    single_rep_map = {
+                        "v": "yu",
+                        "e": "e",
+                        "i": "y",
+                        "u": "w",
+                    }
+                    if pinyin[0] in single_rep_map.keys():
+                        pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
+
+            assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
+            new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
+            new_v = new_v + tone
+            phone = [new_c, new_v]
+            word2ph.append(len(phone))
+
+        phones_list += phone
+    return phones_list, word2ph
+
+
+def _build_segment_without_g2pw(segment: str, seg_cut):
+    initials = []
+    finals = []
+    for word, pos in seg_cut:
+        if pos == "eng":
+            continue
+        sub_initials, sub_finals = _get_initials_finals(word)
+        sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
+        sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
+        initials.append(sub_initials)
+        finals.append(sub_finals)
+    phones_list = []
+    word2ph = []
+    for c, v in zip(sum(initials, []), sum(finals, [])):
+        raw_pinyin = c + v
+        if c == v:
+            assert c in punctuation
+            phone = [c]
+            word2ph.append(1)
+        else:
+            v_without_tone = v[:-1]
+            tone = v[-1]
+            pinyin = c + v_without_tone
+            assert tone in "12345"
+            if c:
+                v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"}
+                if v_without_tone in v_rep_map:
+                    pinyin = c + v_rep_map[v_without_tone]
+            else:
+                pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"}
+                if pinyin in pinyin_rep_map:
+                    pinyin = pinyin_rep_map[pinyin]
+                else:
+                    single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"}
+                    if pinyin[0] in single_rep_map:
+                        pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
+            assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
+            new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
+            new_v = new_v + tone
+            phone = [new_c, new_v]
+            word2ph.append(len(phone))
+        phones_list += phone
+    return phones_list, word2ph
+
+
+def g2p_segments(segments, return_profile: bool = False):
+    prepare_start = time.perf_counter()
+    prepared_segments, batch_inputs = _prepare_g2p_segments(segments)
+    profile = {
+        "g2pw_prepare_ms": 0.0,
+        "g2pw_predict_ms": 0.0,
+        "g2pw_post_ms": 0.0,
+        "g2pw_runtime_total_ms": 0.0,
+        "g2pw_runtime_queue_wait_ms": 0.0,
+        "g2pw_runtime_collect_wait_ms": 0.0,
+        "g2pw_runtime_run_ms": 0.0,
+        "g2pw_runtime_batch_rows": 0.0,
+        "g2pw_runtime_batch_requests": 0.0,
+        "g2pw_runtime_pool_workers": 0.0,
+        "g2pw_runtime_shard_index": 0.0,
+    }
+    profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0)
+    if is_g2pw and batch_inputs:
+        converter = g2pw._g2pw
+        if hasattr(converter, "predict_sentences_with_profile"):
+            g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs)
+            for key, value in dict(predict_profile or {}).items():
+                profile[key] = float(value)
+        else:
+            predict_start = time.perf_counter()
+            g2pw_batch_results = converter(batch_inputs)
+            profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0)
+    else:
+        g2pw_batch_results = []
+    post_start = time.perf_counter()
+    results = []
+    batch_cursor = 0
+    for item in prepared_segments:
+        segment = item["segment"]
+        if not segment:
+            results.append(([], [], segment))
+            continue
+        if not is_g2pw:
+            phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"])
+            results.append((phones, word2ph, segment))
+            continue
+        pinyins = g2pw_batch_results[batch_cursor]
+        batch_cursor += 1
+        phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins)
+        results.append((phones, word2ph, segment))
+    profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0)
+    profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"])
+    if return_profile:
+        return results, profile
+    return results
+
+
 def _get_initials_finals(word):
     initials = []
     finals = []
@@ -180,125 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
 def _g2p(segments):
     phones_list = []
     word2ph = []
-    g2pw_batch_results = []
-    g2pw_batch_cursor = 0
-    processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
-    if is_g2pw:
-        batch_inputs = [seg for seg in processed_segments if seg]
-        g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []
-
-    for seg in processed_segments:
-        pinyins = []
-        seg_cut = psg.lcut(seg)
-        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
-        initials = []
-        finals = []
-
-        if not is_g2pw:
-            for word, pos in seg_cut:
-                if pos == "eng":
-                    continue
-                sub_initials, sub_finals = _get_initials_finals(word)
-                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
-                # 儿化
-                sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
-                initials.append(sub_initials)
-                finals.append(sub_finals)
-                # assert len(sub_initials) == len(sub_finals) == len(word)
-            initials = sum(initials, [])
-            finals = sum(finals, [])
-            print("pypinyin结果", initials, finals)
-        else:
-            # g2pw采用整句推理(批量推理,逐句取结果)
-            if seg:
-                pinyins = g2pw_batch_results[g2pw_batch_cursor]
-                g2pw_batch_cursor += 1
-
-            pre_word_length = 0
-            for word, pos in seg_cut:
-                sub_initials = []
-                sub_finals = []
-                now_word_length = pre_word_length + len(word)
-
-                if pos == "eng":
-                    pre_word_length = now_word_length
-                    continue
-
-                word_pinyins = pinyins[pre_word_length:now_word_length]
-
-                # 多音字消歧
-                word_pinyins = correct_pronunciation(word, word_pinyins)
-
-                for pinyin in word_pinyins:
-                    if pinyin[0].isalpha():
-                        sub_initials.append(to_initials(pinyin))
-                        sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
-                    else:
-                        sub_initials.append(pinyin)
-                        sub_finals.append(pinyin)
-
-                pre_word_length = now_word_length
-                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
-                # 儿化
-                sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
-                initials.append(sub_initials)
-                finals.append(sub_finals)
-
-            initials = sum(initials, [])
-            finals = sum(finals, [])
-            # print("g2pw结果",initials,finals)
-
-        for c, v in zip(initials, finals):
-            raw_pinyin = c + v
-            # NOTE: post process for pypinyin outputs
-            # we discriminate i, ii and iii
-            if c == v:
-                assert c in punctuation
-                phone = [c]
-                word2ph.append(1)
-            else:
-                v_without_tone = v[:-1]
-                tone = v[-1]
-
-                pinyin = c + v_without_tone
-                assert tone in "12345"
-
-                if c:
-                    # 多音节
-                    v_rep_map = {
-                        "uei": "ui",
-                        "iou": "iu",
-                        "uen": "un",
-                    }
-                    if v_without_tone in v_rep_map.keys():
-                        pinyin = c + v_rep_map[v_without_tone]
-                else:
-                    # 单音节
-                    pinyin_rep_map = {
-                        "ing": "ying",
-                        "i": "yi",
-                        "in": "yin",
-                        "u": "wu",
-                    }
-                    if pinyin in pinyin_rep_map.keys():
-                        pinyin = pinyin_rep_map[pinyin]
-                    else:
-                        single_rep_map = {
-                            "v": "yu",
-                            "e": "e",
-                            "i": "y",
-                            "u": "w",
-                        }
-                        if pinyin[0] in single_rep_map.keys():
-                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
-
-                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
-                new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
-                new_v = new_v + tone
-                phone = [new_c, new_v]
-                word2ph.append(len(phone))
-
-            phones_list += phone
+    for phones, item_word2ph, _segment in g2p_segments(segments):
+        phones_list += phones
+        word2ph += item_word2ph
     return phones_list, word2ph

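Usage shape of the new batch entry point (this needs the G2PW model assets on disk, so treat it as a sketch rather than a portable test):

```python
from text import chinese2

results, profile = chinese2.g2p_segments(["你好世界", "今天天气不错"], return_profile=True)
for phones, word2ph, norm_text in results:
    print(norm_text, len(phones), word2ph)
print(profile["g2pw_predict_ms"], profile["g2pw_total_ms"])
```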
diff --git a/GPT_SoVITS/text/g2pw/cuda_api.py b/GPT_SoVITS/text/g2pw/cuda_api.py
new file mode 100644
index 00000000..e1a84748
--- /dev/null
+++ b/GPT_SoVITS/text/g2pw/cuda_api.py
@@ -0,0 +1,670 @@
+import ctypes
+import fcntl
+import os
+import subprocess
+import threading
+import time
+from collections import deque
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Deque, Dict, List, Tuple
+
+import numpy as np
+
+from .onnx_api import _G2PWBaseOnnxConverter
+
+
+class G2PWCudaError(RuntimeError):
+    pass
+
+
+@dataclass
+class G2PWBatchTask:
+    model_input: Dict[str, np.ndarray]
+    created_at: float = field(default_factory=time.perf_counter)
+    enqueued_at: float = 0.0
+    done_event: threading.Event = field(default_factory=threading.Event)
+    output: np.ndarray | None = None
+    profile: Dict[str, float] = field(default_factory=dict)
+    error: Exception | None = None
+
+
+_ROOT_DIR = Path(__file__).resolve().parents[3]
+_PACKAGE_DIR = Path(__file__).resolve().parent
+_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge"
+_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp"
+_LOCK_PATH = _OUTPUT_DIR / "build.lock"
+
+
+def _env_flag(name: str, default: bool) -> int:
+    raw = os.environ.get(name)
+    if raw is None:
+        return 1 if default else 0
+    return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1
+
+
+def _env_int(name: str, default: int) -> int:
+    raw = os.environ.get(name)
+    if raw is None or raw.strip() == "":
+        return int(default)
+    return int(raw)
+
+
+def _resolve_cuda_root() -> Path:
+    env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip()
+    candidates = [
+        env_root,
+        _ROOT_DIR / "third_party" / "g2pw-cu",
+    ]
+    for candidate in candidates:
+        if not candidate:
+            continue
+        path = Path(candidate).expanduser().resolve()
+        if path.exists():
+            return path
+    checked = [
+        str(Path(candidate).expanduser().resolve())
+        for candidate in candidates
+        if str(candidate).strip() != ""
+    ]
+    raise G2PWCudaError(
+        "Cannot locate g2pw-cu root. "
+        "Expected one of: "
+        f"{checked}. "
+        "Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into "
+        f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} "
+        "or set GPTSOVITS_G2PW_CUDA_ROOT explicitly."
+    )
+
+
+def _resolve_runtime_paths() -> tuple[Path, Path, Path]:
+    cuda_root = _resolve_cuda_root()
+    runtime_lib = Path(
+        os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so"))
+    ).expanduser()
+    manifest_path = Path(
+        os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt"))
+    ).expanduser()
+    weights_path = Path(
+        os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin"))
+    ).expanduser()
+    for path in (runtime_lib, manifest_path, weights_path):
+        if not path.exists():
+            raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}")
+    return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve()
+
+
+def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None:
+    _OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+    compile_cmd = [
+        os.environ.get("CXX", "g++"),
+        "-O3",
+        "-std=c++17",
+        "-shared",
+        "-fPIC",
+        str(_WRAPPER_SOURCE),
+        "-I",
+        str(runtime_lib.parent.parent / "include"),
+        "-L",
+        str(runtime_lib.parent),
+        "-lg2pw_runtime",
+        f"-Wl,-rpath,{runtime_lib.parent}",
+        "-o",
+        str(wrapper_output),
+    ]
+    result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False)
+    if result.returncode != 0:
+        raise G2PWCudaError(
+            "Failed to build g2pw-cu bridge:\n"
+            f"cmd={' '.join(compile_cmd)}\n"
+            f"stdout={result.stdout}\n"
+            f"stderr={result.stderr}"
+        )
+
+
+def _ensure_bridge_built(runtime_lib: Path) -> Path:
+    wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so"
+    _OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+    with _LOCK_PATH.open("w", encoding="utf-8") as lock_file:
+        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
+        needs_build = not wrapper_output.exists()
+        if not needs_build:
+            so_mtime = wrapper_output.stat().st_mtime
+            needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime
+        if needs_build:
+            tmp_output = wrapper_output.with_suffix(".tmp.so")
+            if tmp_output.exists():
+                tmp_output.unlink()
+            _build_bridge(tmp_output, runtime_lib)
+            tmp_output.replace(wrapper_output)
+    return wrapper_output
+
+
+def _load_bridge():
+    runtime_lib, manifest_path, weights_path = _resolve_runtime_paths()
+    bridge_path = _ensure_bridge_built(runtime_lib)
+    global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0))
+    ctypes.CDLL(str(runtime_lib), mode=global_mode)
+    lib = ctypes.CDLL(str(bridge_path))
+    lib.g2pw_runtime_create.argtypes = [
+        ctypes.c_char_p,
+        ctypes.c_char_p,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+        ctypes.c_int,
+    ]
+    lib.g2pw_runtime_create.restype = ctypes.c_void_p
+    lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p]
+    lib.g2pw_runtime_destroy.restype = None
+    lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
+    lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
+    lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]
+    lib.g2pw_runtime_num_labels.restype = ctypes.c_int
+    lib.g2pw_runtime_run.argtypes = [
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.c_void_p,
+    ]
+    lib.g2pw_runtime_run.restype = ctypes.c_int
+    return lib, manifest_path, weights_path, runtime_lib
+
+
+def _gemm_precision_value() -> int:
+    precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower()
+    if precision == "fp16":
+        return 1
+    if precision == "bf16":
+        return 2
+    return 0
+
+
+class G2PWRuntimeWrapper:
+    def __init__(self, shard_index: int = 0) -> None:
+        self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge()
+        self.shard_index = int(shard_index)
+        self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0)
+        self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False)
+        self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False)
+        self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False)
+        self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True)
+        self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False)
+        self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0)
+        self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0)
+        self.gemm_precision = _gemm_precision_value()
+        self.lock = threading.Lock()
+        self.handle = None
+        self.max_batch_size = 0
+        self.max_seq_len = 0
+        self.num_labels = 0
+        self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0
+        self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0)
+        self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64))
+        self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96))
+        self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096))
+        self.batch_condition = threading.Condition()
+        self.pending_tasks: Deque[G2PWBatchTask] = deque()
+        self.batch_total_tasks = 0
+        self.batch_total_batches = 0
+        self.batch_total_rows = 0
+        self.batch_total_queue_wait_ms = 0.0
+        self.batch_queue_wait_peak_ms = 0.0
+        self.batch_total_collect_wait_ms = 0.0
+        self.batch_collect_wait_peak_ms = 0.0
+        self.batch_total_run_ms = 0.0
+        self.batch_run_peak_ms = 0.0
+        self.batch_rows_peak = 0
+        self.batch_requests_peak = 0
+        self.batch_pending_peak = 0
+        self.closed = False
+        self._ensure_capacity(
+            batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)),
+            seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)),
+        )
+        self.batch_worker = None
+        if self.batch_enabled:
+            self.batch_worker = threading.Thread(
+                target=self._batch_loop,
+                name=f"g2pw-cuda-batch-worker-{self.shard_index}",
+                daemon=True,
+            )
+            self.batch_worker.start()
+
+    def _destroy_handle(self) -> None:
+        if self.handle:
+            self.lib.g2pw_runtime_destroy(self.handle)
+            self.handle = None
+
+    def close(self) -> None:
+        with self.batch_condition:
+            self.closed = True
+            self.batch_condition.notify_all()
+        self._destroy_handle()
+
+    def __del__(self):
+        try:
+            self.close()
+        except Exception:
+            pass
+
+    def _last_error(self) -> str:
+        if not self.handle:
+            return "uninitialized runtime"
+        message = self.lib.g2pw_runtime_last_error(self.handle)
+        return "" if not message else message.decode("utf-8", errors="replace")
+
+    def _create_handle(self, batch_size: int, seq_len: int) -> None:
+        new_handle = self.lib.g2pw_runtime_create(
+            str(self.manifest_path).encode("utf-8"),
+            str(self.weights_path).encode("utf-8"),
+            int(self.device_ordinal),
+            int(batch_size),
+            int(seq_len),
+            int(self.full_graph_cache_limit),
+            int(self.tail_graph_cache_limit),
+            int(self.allow_tensor_cores),
+            int(self.use_cublaslt_bias_epilogue),
+            int(self.enable_profiling),
+            int(self.enable_cuda_graph),
+            int(self.dump_graph_cache_stats),
+            int(self.gemm_precision),
+        )
+        if not new_handle:
+            raise G2PWCudaError("g2pw-cu returned null runtime handle")
+        self.handle = new_handle
+        self.max_batch_size = int(batch_size)
+        self.max_seq_len = int(seq_len)
+        self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle))
+        last_error = self._last_error()
+        if self.num_labels <= 0 or last_error:
+            self.close()
+            raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}")
+
+    def _ensure_capacity(self, batch_size: int, seq_len: int) -> None:
+        target_batch = max(1, int(batch_size))
+        target_seq = max(1, int(seq_len))
+        if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len:
+            return
+        next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0)
+        next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0)
+        self._destroy_handle()
+        self._create_handle(batch_size=next_batch, seq_len=next_seq)
+
+    @staticmethod
+    def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64)
+        token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64)
+        attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64)
+        phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32)
+        char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64)
+        position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64)
+        batch_size = int(char_ids.shape[0])
+        if input_ids.shape[0] == 1 and batch_size > 1:
+            input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64)
+            token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64)
+            attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64)
+        return {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_masks": attention_masks,
+            "phoneme_masks": phoneme_masks,
+            "char_ids": char_ids,
+            "position_ids": position_ids,
+        }
+
+    def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
+        normalized = self._normalize_model_input(model_input)
+        input_ids = normalized["input_ids"]
+        token_type_ids = normalized["token_type_ids"]
+        attention_masks = normalized["attention_masks"]
+        phoneme_masks = normalized["phoneme_masks"]
+        char_ids = normalized["char_ids"]
+        position_ids = normalized["position_ids"]
+        batch_size = int(char_ids.shape[0])
+        seq_len = int(input_ids.shape[1])
+        probs = np.empty((batch_size, self.num_labels), dtype=np.float32)
+        with self.lock:
+            self._ensure_capacity(batch_size=batch_size, seq_len=seq_len)
+            status = self.lib.g2pw_runtime_run(
+                self.handle,
+                input_ids.ctypes.data_as(ctypes.c_void_p),
+                token_type_ids.ctypes.data_as(ctypes.c_void_p),
+                attention_masks.ctypes.data_as(ctypes.c_void_p),
+                phoneme_masks.ctypes.data_as(ctypes.c_void_p),
+                char_ids.ctypes.data_as(ctypes.c_void_p),
+                position_ids.ctypes.data_as(ctypes.c_void_p),
+                batch_size,
+                seq_len,
+                probs.ctypes.data_as(ctypes.c_void_p),
+            )
+        if int(status) != 0:
+            raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}")
+        return probs
+
+    def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool:
+        request_count = len(tasks) + 1
+        if request_count > self.batch_max_requests:
+            return False
+        total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int(
+            candidate.model_input["char_ids"].shape[0]
+        )
+        if total_rows > self.batch_max_rows:
+            return False
+        total_tokens = sum(
+            int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks
+        ) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1])
+        return total_tokens <= self.batch_max_tokens
+
+    def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]:
+        normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks]
+        total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs)
+        max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs)
+        input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
+        token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
+        attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64)
+        phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32)
+        char_ids = np.zeros((total_rows,), dtype=np.int64)
+        position_ids = np.zeros((total_rows,), dtype=np.int64)
+        slices: List[Tuple[int, int]] = []
+        cursor = 0
+        for item in normalized_inputs:
+            rows = int(item["char_ids"].shape[0])
+            seq_len = int(item["input_ids"].shape[1])
+            next_cursor = cursor + rows
+            input_ids[cursor:next_cursor, :seq_len] = item["input_ids"]
+            token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"]
+            attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"]
+            phoneme_masks[cursor:next_cursor] = item["phoneme_masks"]
+            char_ids[cursor:next_cursor] = item["char_ids"]
+            position_ids[cursor:next_cursor] = item["position_ids"]
+            slices.append((cursor, next_cursor))
+            cursor = next_cursor
+        return {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_masks": attention_masks,
+            "phoneme_masks": phoneme_masks,
+            "char_ids": char_ids,
+            "position_ids": position_ids,
+        }, slices
+
+    def _finish_task(
+        self,
+        task: G2PWBatchTask,
+        output: np.ndarray | None = None,
+        profile: Dict[str, float] | None = None,
+        error: Exception | None = None,
+    ) -> None:
+        task.output = output
+        task.profile = dict(profile or {})
+        task.error = error
+        task.done_event.set()
+
+    def _batch_loop(self) -> None:
+        while True:
+            with self.batch_condition:
+                while not self.pending_tasks and not self.closed:
+                    self.batch_condition.wait()
+                if self.closed and not self.pending_tasks:
+                    return
+                first_task = self.pending_tasks.popleft()
+                batch_tasks = [first_task]
+                collect_started = time.perf_counter()
+                deadline = collect_started + self.batch_window_s
+                while True:
+                    if len(batch_tasks) >= self.batch_max_requests:
+                        break
+                    remaining = deadline - time.perf_counter()
+                    if remaining <= 0.0:
+                        break
+                    if not self.pending_tasks:
+                        self.batch_condition.wait(timeout=remaining)
+                        continue
+                    candidate = self.pending_tasks[0]
+                    if not self._can_append_task(batch_tasks, candidate):
+                        break
+                    batch_tasks.append(self.pending_tasks.popleft())
+                collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0)
+
+            now = time.perf_counter()
+            queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks]
+            try:
+                merged_input, row_slices = self._merge_batch_inputs(batch_tasks)
+                run_started = time.perf_counter()
+                merged_output = self._run_direct(merged_input)
+                run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0)
+                for task, (start, end) in zip(batch_tasks, row_slices):
+                    task_rows = int(task.model_input["char_ids"].shape[0])
+                    task_seq_len = int(task.model_input["input_ids"].shape[1])
+                    self._finish_task(
+                        task,
+                        output=np.ascontiguousarray(merged_output[start:end]),
+                        profile={
+                            "g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)),
+                            "g2pw_runtime_collect_wait_ms": float(collect_wait_ms),
+                            "g2pw_runtime_run_ms": float(run_ms),
+                            "g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)),
+                            "g2pw_runtime_batch_requests": float(len(batch_tasks)),
+                            "g2pw_runtime_task_rows": float(task_rows),
+                            "g2pw_runtime_task_seq_len": float(task_seq_len),
+                            "g2pw_runtime_shard_index": float(self.shard_index),
+                        },
+                    )
+            except Exception as exc:
+                run_ms = 0.0
+                for task in batch_tasks:
+                    self._finish_task(task, error=exc)
+            finally:
+                with self.batch_condition:
+                    self.batch_total_batches += 1
+                    self.batch_total_tasks += len(batch_tasks)
+                    self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
+                    self.batch_total_queue_wait_ms += float(sum(queue_wait_values))
+                    self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0]))
+                    self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks))
+                    self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms))
+                    self.batch_total_run_ms += float(run_ms)
+                    self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms))
+                    self.batch_rows_peak = max(
+                        self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
+                    )
+                    self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks))
+
+    def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
+        task = G2PWBatchTask(model_input=model_input)
+        with self.batch_condition:
+            if self.closed:
+                raise G2PWCudaError("g2pw-cu batch worker already closed")
+            task.enqueued_at = time.perf_counter()
+            self.pending_tasks.append(task)
+            self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks))
+            self.batch_condition.notify_all()
+        task.done_event.wait()
+        if task.error is not None:
+            raise task.error
+        assert task.output is not None
+        return task.output, dict(task.profile)
+
+    def snapshot(self) -> Dict[str, float | int | bool]:
+        with self.batch_condition:
+            average_tasks_per_batch = (
+                float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
+            )
+            average_rows_per_batch = (
+                float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
+            )
+            average_queue_wait_ms = (
+                float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0
+            )
+            average_collect_wait_ms = (
+                float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks)
+                if self.batch_total_tasks > 0
+                else 0.0
+            )
+            return {
+                "shard_index": int(self.shard_index),
+                "enabled": bool(self.batch_enabled),
+                "window_ms": float(self.batch_window_s * 1000.0),
+                "max_requests": int(self.batch_max_requests),
+                "max_rows": int(self.batch_max_rows),
+                "max_tokens": int(self.batch_max_tokens),
+                "pending": int(len(self.pending_tasks)),
+                "pending_peak": int(self.batch_pending_peak),
+                "total_batches": int(self.batch_total_batches),
+                "total_tasks": int(self.batch_total_tasks),
+                "total_rows": int(self.batch_total_rows),
+                "avg_tasks_per_batch": float(average_tasks_per_batch),
+                "avg_rows_per_batch": float(average_rows_per_batch),
+                "avg_queue_wait_ms": float(average_queue_wait_ms),
+                "queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms),
+                "avg_collect_wait_ms": float(average_collect_wait_ms),
+                "collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms),
+                "run_total_ms": float(self.batch_total_run_ms),
+                "run_peak_ms": float(self.batch_run_peak_ms),
+                "batch_rows_peak": int(self.batch_rows_peak),
+                "batch_requests_peak": int(self.batch_requests_peak),
+            }
+
+    def pending_rows(self) -> int:
+        with self.batch_condition:
+            return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks))
+
+    def pending_count(self) -> int:
+        with self.batch_condition:
+            return int(len(self.pending_tasks))
+
+    def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
+        if not self.batch_enabled:
+            started = time.perf_counter()
+            output = self._run_direct(model_input)
+            return output, {
+                "g2pw_runtime_queue_wait_ms": 0.0,
+                "g2pw_runtime_collect_wait_ms": 0.0,
+                "g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0),
+                "g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]),
+                "g2pw_runtime_batch_requests": 1.0,
+                "g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]),
+                "g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]),
+                "g2pw_runtime_shard_index": float(self.shard_index),
+            }
+        return self._submit_batched(model_input)
+
+    def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
+        output, _profile = self.run_with_profile(model_input)
+        return output
+
+
+class G2PWRuntimePool:
+    def __init__(self) -> None:
+        self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2))
+        self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)]
+        self.lock = threading.Lock()
+
+    def _pick_shard(self) -> G2PWRuntimeWrapper:
+        with self.lock:
+            return min(
+                self.shards,
+                key=lambda shard: (
+                    shard.pending_rows(),
+                    shard.pending_count(),
+                    shard.snapshot().get("avg_queue_wait_ms", 0.0),
+                ),
+            )
+
+    def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
+        shard = self._pick_shard()
+        output, profile = shard.run_with_profile(model_input)
+        profile["g2pw_runtime_pool_workers"] = float(self.worker_count)
+        return output, profile
+
+    def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
+        output, _profile = self.run_with_profile(model_input)
+        return output
+
+    def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]:
+        shard_snapshots = [dict(shard.snapshot()) for shard in self.shards]
+        avg_queue_wait_ms = 0.0
+        total_tasks = 0.0
+        pending = 0
+        pending_peak = 0
+        total_batches = 0
+        total_rows = 0
+        batch_rows_peak = 0
+        batch_requests_peak = 0
+        for snapshot in shard_snapshots:
+            tasks = float(snapshot.get("total_tasks", 0.0))
+            avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks
+            total_tasks += tasks
+            pending += int(snapshot.get("pending", 0))
+            pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0)))
+            total_batches += int(snapshot.get("total_batches", 0))
+            total_rows += int(snapshot.get("total_rows", 0))
+            batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0)))
+            batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0)))
+        return {
+            "worker_count": int(self.worker_count),
+            "pending": int(pending),
+            "pending_peak": int(pending_peak),
+            "total_batches": int(total_batches),
+            "total_tasks": int(total_tasks),
+            "total_rows": int(total_rows),
+            "avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0,
+            "batch_rows_peak": int(batch_rows_peak),
+            "batch_requests_peak": int(batch_requests_peak),
+            "shards": shard_snapshots,
+        }
+
+
+class G2PWCudaConverter(_G2PWBaseOnnxConverter):
+    def __init__(
+        self,
+        model_dir: str = "G2PWModel/",
+        style: str = "bopomofo",
+        model_source: str = None,
+        enable_non_tradional_chinese: bool = False,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            style=style,
+            model_source=model_source,
+            enable_non_tradional_chinese=enable_non_tradional_chinese,
+        )
+        self.runtime = G2PWRuntimePool()
+        self.backend = "cuda"
+        primary_runtime = self.runtime.shards[0]
+        self.device = f"cuda:{primary_runtime.device_ordinal}"
+        self.checkpoint_path = str(primary_runtime.weights_path)
+        self.providers = ["g2pw-cu"]
+
+    def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
+        probs = self.runtime.run(model_input)
+        preds = np.argmax(probs, axis=1).tolist()
+        confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
+        return [self.labels[pred] for pred in preds], confidences
+
+    def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]:
+        started = time.perf_counter()
+        probs, runtime_profile = self.runtime.run_with_profile(model_input)
+        preds = np.argmax(probs, axis=1).tolist()
+        confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
+        profile = dict(runtime_profile)
+        profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0)
+        profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"])
+        return [self.labels[pred] for pred in preds], confidences, profile
+
+    def snapshot(self) -> Dict[str, float | int | bool]:
+        return dict(self.runtime.snapshot())
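The admission rule used by the batch collector reduces to three caps; a condensed restatement of `_can_append_task` with the diff's default limits:

```python
def can_append(batch_rows: int, batch_tokens: int, batch_requests: int,
               rows: int, seq_len: int,
               max_requests: int = 64, max_rows: int = 96, max_tokens: int = 4096) -> bool:
    # A candidate joins the in-flight batch only if all three caps still hold.
    return (batch_requests + 1 <= max_requests
            and batch_rows + rows <= max_rows
            and batch_tokens + rows * seq_len <= max_tokens)


assert can_append(0, 0, 0, rows=8, seq_len=128)
assert not can_append(90, 0, 0, rows=8, seq_len=128)    # row cap exceeded
assert not can_append(0, 4000, 0, rows=8, seq_len=128)  # token cap exceeded
```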
b/GPT_SoVITS/text/g2pw/g2pw.py index 08525e91..ccd05a1b 100644 --- a/GPT_SoVITS/text/g2pw/g2pw.py +++ b/GPT_SoVITS/text/g2pw/g2pw.py @@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style from pypinyin.seg.simpleseg import simple_seg from pypinyin.converter import UltimateConverter from pypinyin.contrib.tone_convert import to_tone +from .cuda_api import G2PWCudaConverter from .onnx_api import G2PWOnnxConverter current_file_path = os.path.dirname(__file__) @@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin): tone_sandhi=False, **kwargs, ): - self._g2pw = G2PWOnnxConverter( - model_dir=model_dir, - style="pinyin", - model_source=model_source, - enable_non_tradional_chinese=enable_non_tradional_chinese, - ) + backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower() + last_error = None + self._g2pw = None + if backend in {"cuda", "auto"}: + try: + self._g2pw = G2PWCudaConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + except Exception as exc: + last_error = exc + strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in { + "1", + "true", + "yes", + "on", + } + if backend == "cuda" and strict_mode: + raise + if self._g2pw is None: + self._g2pw = G2PWOnnxConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + if last_error is not None: + print(f"[g2pw] cuda backend unavailable, fallback to onnx: {last_error}") self._converter = Converter( self._g2pw, v_to_u=v_to_u, diff --git a/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp b/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp new file mode 100644 index 00000000..dc8f29a8 --- /dev/null +++ b/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp @@ -0,0 +1,183 @@ +#include +#include +#include +#include + +#include "g2pw/runtime.h" + +namespace { + +struct G2PWRuntimeHandle { + std::unique_ptr runtime; + std::string last_error; + int num_labels = 0; +}; + +void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) { + if (handle == nullptr) { + return; + } + handle->last_error = status.message; +} + +g2pw::RuntimeConfig BuildConfig( + int device_ordinal, + int max_batch_size, + int max_seq_len, + int full_graph_cache_limit, + int tail_graph_cache_limit, + int allow_tensor_cores, + int use_cublaslt_bias_epilogue, + int enable_profiling, + int enable_cuda_graph, + int dump_graph_cache_stats, + int gemm_precision) { + g2pw::RuntimeConfig config{}; + config.device_ordinal = device_ordinal; + config.max_batch_size = max_batch_size; + config.max_seq_len = max_seq_len; + config.full_graph_cache_limit = full_graph_cache_limit; + config.tail_graph_cache_limit = tail_graph_cache_limit; + config.allow_tensor_cores = allow_tensor_cores != 0; + config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0; + config.enable_profiling = enable_profiling != 0; + config.enable_cuda_graph = enable_cuda_graph != 0; + config.dump_graph_cache_stats = dump_graph_cache_stats != 0; + switch (gemm_precision) { + case 1: + config.gemm_precision = g2pw::GemmPrecision::kFp16; + break; + case 2: + config.gemm_precision = g2pw::GemmPrecision::kBf16; + break; + default: + config.gemm_precision = g2pw::GemmPrecision::kFp32; + break; + } + return config; +} + +} // namespace + +extern "C" { + +void* g2pw_runtime_create( + const char* manifest_path, + const char* binary_path, + int device_ordinal, + int max_batch_size, + int max_seq_len, + int full_graph_cache_limit, + 
+    int tail_graph_cache_limit,
+    int allow_tensor_cores,
+    int use_cublaslt_bias_epilogue,
+    int enable_profiling,
+    int enable_cuda_graph,
+    int dump_graph_cache_stats,
+    int gemm_precision) {
+  auto* handle = new G2PWRuntimeHandle();
+  try {
+    if (manifest_path == nullptr || binary_path == nullptr) {
+      handle->last_error = "manifest_path and binary_path must be non-null";
+      return handle;
+    }
+    g2pw::RuntimeConfig config = BuildConfig(
+        device_ordinal,
+        max_batch_size,
+        max_seq_len,
+        full_graph_cache_limit,
+        tail_graph_cache_limit,
+        allow_tensor_cores,
+        use_cublaslt_bias_epilogue,
+        enable_profiling,
+        enable_cuda_graph,
+        dump_graph_cache_stats,
+        gemm_precision);
+    g2pw::Status status = g2pw::Runtime::Create(
+        config,
+        std::string(manifest_path),
+        std::string(binary_path),
+        &handle->runtime);
+    if (!status.ok()) {
+      SetError(handle, status);
+      return handle;
+    }
+    handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0;
+    handle->last_error.clear();
+    return handle;
+  } catch (const std::exception& exc) {
+    handle->last_error = exc.what();
+    return handle;
+  } catch (...) {
+    handle->last_error = "unknown exception";
+    return handle;
+  }
+}
+
+void g2pw_runtime_destroy(void* raw_handle) {
+  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
+  delete handle;
+}
+
+const char* g2pw_runtime_last_error(void* raw_handle) {
+  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
+  if (handle == nullptr) {
+    return "invalid runtime handle";
+  }
+  return handle->last_error.c_str();
+}
+
+int g2pw_runtime_num_labels(void* raw_handle) {
+  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
+  if (handle == nullptr || handle->runtime == nullptr) {
+    return 0;
+  }
+  return handle->num_labels;
+}
+
+int g2pw_runtime_run(
+    void* raw_handle,
+    const std::int64_t* input_ids,
+    const std::int64_t* token_type_ids,
+    const std::int64_t* attention_mask,
+    const float* phoneme_mask,
+    const std::int64_t* char_ids,
+    const std::int64_t* position_ids,
+    std::int32_t batch_size,
+    std::int32_t seq_len,
+    float* probs) {
+  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
+  if (handle == nullptr || handle->runtime == nullptr) {
+    return static_cast<int>(g2pw::StatusCode::kInvalidArgument);
+  }
+  try {
+    g2pw::InferenceInputs inputs{};
+    inputs.input_ids = input_ids;
+    inputs.token_type_ids = token_type_ids;
+    inputs.attention_mask = attention_mask;
+    inputs.phoneme_mask = phoneme_mask;
+    inputs.char_ids = char_ids;
+    inputs.position_ids = position_ids;
+    inputs.batch_size = batch_size;
+    inputs.seq_len = seq_len;
+
+    g2pw::InferenceOutputs outputs{};
+    outputs.probs = probs;
+
+    const g2pw::Status status = handle->runtime->Run(inputs, outputs);
+    if (!status.ok()) {
+      SetError(handle, status);
+      return static_cast<int>(status.code);
+    }
+    handle->last_error.clear();
+    return static_cast<int>(g2pw::StatusCode::kOk);
+  } catch (const std::exception& exc) {
+    handle->last_error = exc.what();
+    return static_cast<int>(g2pw::StatusCode::kInternalError);
+  } catch (...) {
+    handle->last_error = "unknown exception";
+    return static_cast<int>(g2pw::StatusCode::kInternalError);
+  }
+}
+
+}
diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py
index 3c2b0169..f6d7fab7 100644
--- a/GPT_SoVITS/text/g2pw/onnx_api.py
+++ b/GPT_SoVITS/text/g2pw/onnx_api.py
@@ -3,6 +3,7 @@
 import json
 import os
+import time
 import warnings
 import zipfile
 from typing import Any, Dict, List, Tuple

@@ -71,6 +72,23 @@ def _find_first_existing_file(*paths: str) -> str:
     raise FileNotFoundError(f"Files not found: {paths}")


+def _resolve_tokenizer_source(model_source: str | None) -> str:
+    candidate_paths = []
+    if model_source:
+        candidate_paths.append(model_source)
+    repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+    candidate_paths.extend(
+        [
+            os.path.join(repo_root, "pretrained_models", "g2pw-chinese"),
+            os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"),
+        ]
+    )
+    for candidate in candidate_paths:
+        if candidate and os.path.exists(candidate):
+            return candidate
+    return model_source or "bert-base-chinese"
+
+
 def download_and_decompress(model_dir: str = "G2PWModel/"):
     if not os.path.exists(model_dir):
         parent_directory = os.path.dirname(model_dir)
@@ -106,9 +124,9 @@ class _G2PWBaseOnnxConverter:
         self.model_dir = download_and_decompress(model_dir)
         self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
-        self.model_source = model_source if model_source else self.config.model_source
+        self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source)
         self.enable_opencc = enable_non_tradional_chinese
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True)

         polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
         monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
@@ -200,6 +218,10 @@
         return None

     def __call__(self, sentences: List[str]) -> List[List[str]]:
+        results, _profile = self.predict_sentences_with_profile(sentences)
+        return results
+
+    def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]:
         if isinstance(sentences, str):
             sentences = [sentences]

@@ -213,7 +235,7 @@
         texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)

         if len(texts) == 0:
-            return partial_results
+            return partial_results, {}

         model_input = prepare_onnx_input(
             tokenizer=self.tokenizer,
@@ -229,12 +251,21 @@
         )

         if not model_input:
-            return partial_results
+            return partial_results, {}
+
+        predict_profile: Dict[str, float] = {}
         if self.enable_sentence_dedup:
-            preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
+            preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled(
+                model_input=model_input,
+                texts=texts,
+            )
         else:
-            preds, _confidences = self._predict(model_input=model_input)
+            if hasattr(self, "_predict_with_profile"):
+                preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input)
+            else:
+                predict_started = time.perf_counter()
+                preds, _confidences = self._predict(model_input=model_input)
+                predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0)

         if self.config.use_char_phoneme:
             preds = [pred.split(" ")[1] for pred in preds]
@@ -243,7 +274,7 @@
         for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
             results[sent_id][query_id] = self.style_convert_func(pred)

-        return results
+        return results, predict_profile

     def _prepare_data(
         self, sentences: List[str]
@@ -314,6 +345,52 @@
         return preds, confidences
+
+    def _predict_with_sentence_dedup_profiled(
+        self,
+        model_input: Dict[str, Any],
+        texts: List[str],
+    ) -> Tuple[List[str], List[float], Dict[str, float]]:
+        if len(texts) <= 1:
+            if hasattr(self, "_predict_with_profile"):
+                return self._predict_with_profile(model_input=model_input)
+            predict_started = time.perf_counter()
+            preds, confidences = self._predict(model_input=model_input)
+            return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
+
+        grouped_indices: Dict[str, List[int]] = {}
+        for idx, text in enumerate(texts):
+            grouped_indices.setdefault(text, []).append(idx)
+
+        if all(len(indices) == 1 for indices in grouped_indices.values()):
+            if hasattr(self, "_predict_with_profile"):
+                return self._predict_with_profile(model_input=model_input)
+            predict_started = time.perf_counter()
+            preds, confidences = self._predict(model_input=model_input)
+            return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
+
+        preds: List[str] = [""] * len(texts)
+        confidences: List[float] = [0.0] * len(texts)
+        merged_profile: Dict[str, float] = {}
+        for indices in grouped_indices.values():
+            group_input = {name: value[indices] for name, value in model_input.items()}
+            if len(indices) > 1:
+                for name in ("input_ids", "token_type_ids", "attention_masks"):
+                    group_input[name] = group_input[name][:1]
+            if hasattr(self, "_predict_with_profile"):
+                group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input)
+                for key, value in dict(group_profile or {}).items():
+                    merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value)
+            else:
+                predict_started = time.perf_counter()
+                group_preds, group_confidences = self._predict(model_input=group_input)
+                merged_profile["g2pw_predict_ms"] = float(
+                    merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0
+                )
+            for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
+                preds[output_idx] = pred
+                confidences[output_idx] = confidence
+        return preds, confidences, merged_profile


 class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
     def __init__(
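Usage note (not part of the patch): a minimal sketch of how the backend switch and the profiled entry point introduced above are expected to be driven. Everything here is taken from this diff except the import path, the constructor arguments, and the sample sentence, which are assumptions; `_g2pw` is a private attribute, so production code would normally go through the public `__call__` instead.

    import os

    # "cuda" (the default) or "auto" tries the CUDA runtime first and falls
    # back to ONNX on failure; any other value skips the CUDA attempt.
    os.environ["GPTSOVITS_G2PW_BACKEND"] = "auto"
    # Only consulted when the backend is exactly "cuda": a truthy value
    # ("1", "true", "yes", "on") re-raises the CUDA init error instead of
    # falling back to ONNX.
    os.environ["GPTSOVITS_G2PW_CUDA_STRICT"] = "0"

    # Assumed import path; set the env vars before constructing the converter.
    from text.g2pw.g2pw import G2PWPinyin

    g2pw = G2PWPinyin(model_dir="G2PWModel/", model_source=None)

    # __call__ still returns plain per-sentence results; the new profiled
    # entry point also returns timings such as "g2pw_predict_ms" (the dict
    # may be empty when no polyphonic model query was needed).
    results, profile = g2pw._g2pw.predict_sentences_with_profile(["这个人有点意思"])
    print(results, profile.get("g2pw_predict_ms"))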