Implement G2PW processing enhancements in TTS framework

Defer Chinese G2PW resolution out of the CPU text-preprocessing stage and into its own pipeline stage. Text preprocessing now only normalizes Chinese segments and marks them with needs_g2pw; PrepareCoordinator resolves the marked segments on a dedicated G2PW worker pool behind an AsyncStageGate, overlapping the work with reference-audio loading. G2PW inference gains a batched CUDA backend (a g2pw-cu runtime pool, with fallback to the existing ONNX converter), and per-stage profiling counters (prepare/predict/post, plus runtime queue, collect, and run times) are threaded through to the request state. The result is faster and more observable Chinese text processing.
baicai-1145 2026-03-12 23:04:39 +08:00
parent 5cf68a91d3
commit 17cb2e5acf
10 changed files with 1417 additions and 149 deletions
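
The new stage is sized through environment variables, with defaults derived from the existing text-CPU worker count. A sketch of the prepare-pipeline knobs this commit adds (values are illustrative, not recommendations):

import os

os.environ["GPTSOVITS_PREPARE_G2PW_WORKERS"] = "8"       # prepare-g2pw thread pool size
os.environ["GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT"] = "8"  # AsyncStageGate admission limit

Backend selection (GPTSOVITS_G2PW_BACKEND) and the CUDA batching knobs (GPTSOVITS_G2PW_CUDA_*) are shown alongside their respective files below.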

View File

@ -529,6 +529,7 @@ class TTS:
self.bert_model,
self.bert_tokenizer,
self.configs.device,
version=self.configs.version,
bert_stage_limiter=self.prepare_bert_stage_limiter,
bert_batch_worker=self.prepare_bert_batch_worker,
)
@ -558,6 +559,16 @@ class TTS:
return None
def snapshot_prepare_runtime_components(self) -> dict:
g2pw_runtime = None
try:
from text import chinese2
g2pw_instance = getattr(chinese2, "g2pw", None)
g2pw_backend = None if g2pw_instance is None else getattr(g2pw_instance, "_g2pw", None)
if g2pw_backend is not None and hasattr(g2pw_backend, "snapshot"):
g2pw_runtime = dict(g2pw_backend.snapshot())
except Exception:
g2pw_runtime = None
return {
"text_cpu": {
"workers": int(self.prepare_text_cpu_workers),
@ -587,6 +598,7 @@ class TTS:
"text_preprocessor": (
None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot()
),
"g2pw": g2pw_runtime,
}
def _build_text_cpu_admission_state(self) -> dict:
@ -1204,6 +1216,9 @@ class TTS:
def prepare_text_segments(self, text: str, language: str):
return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version)
def resolve_g2pw_segments(self, prepared_segments, profile: dict | None = None):
return self.text_preprocessor.resolve_g2pw_segments(prepared_segments, profile=profile)
def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None):
return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
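
Taken together, the three methods split preparation into separable stages. A minimal sketch of the intended call order on a constructed TTS instance (text and language code are illustrative):

profile: dict = {}

# stage 1: CPU-only normalization; zh segments return with needs_g2pw=True and empty phones
segments = tts.prepare_text_segments("你好世界", "zh")

# stage 2: batch-resolve the deferred Chinese segments through G2PW
segments = tts.resolve_g2pw_segments(segments, profile=profile)

# stage 3: BERT features; this re-runs resolve_g2pw_segments, which is a no-op once
# every segment has needs_g2pw=False, so calling it here is safe
phones, bert, norm_text = tts.build_text_features_from_segments(segments, profile=profile)
print(profile.get("g2pw_total_ms", 0.0))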

View File

@ -101,6 +101,7 @@ class PreparedTextSegment:
phones: List[int]
word2ph: Optional[List[int]]
norm_text: str
needs_g2pw: bool = False
class TextPreprocessor:
@ -109,12 +110,14 @@ class TextPreprocessor:
bert_model: AutoModelForMaskedLM,
tokenizer: AutoTokenizer,
device: torch.device,
version: str = "v2",
bert_stage_limiter: StageLimiter | None = None,
bert_batch_worker: PrepareBertBatchWorker | None = None,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.version = str(version)
self.bert_stage_limiter = bert_stage_limiter
self.bert_batch_worker = bert_batch_worker
@ -261,15 +264,66 @@ class TextPreprocessor:
phones=list(payload["phones"]),
word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
norm_text=str(payload["norm_text"]),
needs_g2pw=bool(payload.get("needs_g2pw", False)),
)
for payload in payloads
]
def resolve_g2pw_segments(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> List[PreparedTextSegment]:
zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)]
if not zh_indices:
return prepared_segments
from text import chinese2
normalized_segments = [prepared_segments[index].norm_text for index in zh_indices]
resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True)
self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0))
self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0))
self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0))
self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0))
self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0))
self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0))
self._accumulate_profile(
profile,
"g2pw_runtime_collect_wait_ms",
g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0),
)
self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0))
self._update_profile_peak(
profile,
"g2pw_runtime_batch_rows_peak",
g2pw_profile.get("g2pw_runtime_batch_rows", 0.0),
)
self._update_profile_peak(
profile,
"g2pw_runtime_batch_requests_peak",
g2pw_profile.get("g2pw_runtime_batch_requests", 0.0),
)
self._update_profile_peak(
profile,
"g2pw_runtime_pool_workers",
g2pw_profile.get("g2pw_runtime_pool_workers", 0.0),
)
for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments):
prepared_segments[index] = PreparedTextSegment(
language=prepared_segments[index].language,
phones=list(cleaned_text_to_sequence(phones, self.version)),
word2ph=None if word2ph is None else list(word2ph),
norm_text=str(norm_text),
needs_g2pw=False,
)
return prepared_segments
def build_phones_and_bert_from_segments(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> Tuple[list, torch.Tensor, str]:
prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
phones_list: List[List[int]] = []
bert_list: List[torch.Tensor] = []
norm_text_list: List[str] = []
@ -402,6 +456,7 @@ class TextPreprocessor:
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> Tuple[list, torch.Tensor, str]:
prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
for segment_index, segment in enumerate(prepared_segments):
@ -473,6 +528,8 @@ class TextPreprocessor:
prompt_profile: Dict | None = None,
target_profile: Dict | None = None,
) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile)
target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile)
prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
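
resolve_g2pw_segments is deliberately idempotent: the fast path returns immediately when no segment has needs_g2pw set, and every resolved segment is rewritten with needs_g2pw=False. A sketch of that contract, using hand-built segments and a hypothetical preprocessor instance:

pending = PreparedTextSegment(language="zh", phones=[], word2ph=None,
                              norm_text="你好", needs_g2pw=True)
done = PreparedTextSegment(language="en", phones=[1, 2, 3], word2ph=None,
                           norm_text="hello", needs_g2pw=False)

profile: dict = {}
resolved = preprocessor.resolve_g2pw_segments([pending, done], profile=profile)
assert not any(seg.needs_g2pw for seg in resolved)  # zh segment now carries real phone ids
# a second call hits the early return and leaves the accumulated g2pw_* counters untouched
resolved = preprocessor.resolve_g2pw_segments(resolved, profile=profile)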

View File

@ -120,6 +120,15 @@ class PrepareCoordinator:
max_workers=self.text_feature_workers,
thread_name_prefix="prepare-text-feature",
)
g2pw_default_workers = max(8, int(getattr(tts, "prepare_text_cpu_workers", 8) or 8))
self.g2pw_workers = max(
1,
int(os.environ.get("GPTSOVITS_PREPARE_G2PW_WORKERS", str(g2pw_default_workers))),
)
self.g2pw_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self.g2pw_workers,
thread_name_prefix="prepare-g2pw",
)
ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
self.ref_audio_workers = max(
1,
@ -130,12 +139,17 @@ class PrepareCoordinator:
thread_name_prefix="prepare-ref-audio",
)
text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0))
g2pw_gate_default = max(0, int(self.g2pw_workers))
text_feature_gate_default = max(0, int(self.text_feature_workers))
ref_audio_gate_default = max(0, int(self.ref_audio_workers))
self.text_cpu_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))),
poll_ms=gate_poll_ms,
)
self.g2pw_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", str(g2pw_gate_default))),
poll_ms=gate_poll_ms,
)
self.text_feature_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))),
poll_ms=gate_poll_ms,
@ -172,6 +186,7 @@ class PrepareCoordinator:
"peak_inflight": int(self.peak_inflight),
"max_inflight": int(self.max_inflight),
"text_feature_workers": int(self.text_feature_workers),
"g2pw_workers": int(self.g2pw_workers),
"ref_audio_workers": int(self.ref_audio_workers),
}
runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None)
@ -182,6 +197,7 @@ class PrepareCoordinator:
snapshot["prepare_runtime_state"] = None
snapshot["prepare_stage_gates"] = {
"text_cpu": self.text_cpu_gate.snapshot(),
"g2pw": self.g2pw_gate.snapshot(),
"text_feature": self.text_feature_gate.snapshot(),
"ref_audio": self.ref_audio_gate.snapshot(),
"ref_load": self.ref_load_gate.snapshot(),
@ -204,6 +220,11 @@ class PrepareCoordinator:
def _prepare_text_cpu(self, text: str, language: str):
return self.tts.prepare_text_segments(text, language)
def _resolve_g2pw_segments(self, prepared_segments):
profile: Dict[str, float] = {}
resolved_segments = self.tts.resolve_g2pw_segments(prepared_segments, profile=profile)
return resolved_segments, profile
def _load_ref_audio_raw(self, ref_audio_path: str):
return self.tts._load_ref_audio_raw(ref_audio_path)
@ -225,8 +246,15 @@ class PrepareCoordinator:
dtype=dtype if dtype is not None else __import__("torch").float32,
)
def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
def _build_text_features(
self,
prepared_segments,
language: str,
cpu_run_ms: float,
base_profile: Dict[str, float] | None = None,
) -> PreparedTextFeatures:
profile: Dict[str, float] = dict(base_profile or {})
profile["cpu_preprocess_ms"] = float(cpu_run_ms)
branch_start = time.perf_counter()
phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile)
total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0)
@ -291,10 +319,53 @@ class PrepareCoordinator:
prepared_segments,
language,
cpu_run_ms,
None,
)
finally:
self.text_feature_gate.release()
async def _run_g2pw_stage(self, prepared_segments) -> ProfiledResult:
has_pending = any(bool(getattr(segment, "needs_g2pw", False)) for segment in (prepared_segments or []))
if not has_pending:
submit_at = time.perf_counter()
return ProfiledResult(
result=prepared_segments,
submit_at=float(submit_at),
started_at=float(submit_at),
finished_at=float(submit_at),
profile={},
)
await self.g2pw_gate.acquire()
try:
profiled = await self._run_on_executor(self.g2pw_executor, self._resolve_g2pw_segments, prepared_segments)
result, stage_profile = profiled.result
return ProfiledResult(
result=result,
submit_at=float(profiled.submit_at),
started_at=float(profiled.started_at),
finished_at=float(profiled.finished_at),
profile=dict(stage_profile),
)
finally:
self.g2pw_gate.release()
async def _run_g2pw_pair_stage(self, prompt_segments, target_segments) -> tuple[ProfiledResult, ProfiledResult]:
prompt_is_empty = len(prompt_segments or []) == 0
target_task = asyncio.create_task(self._run_g2pw_stage(target_segments))
if not prompt_is_empty:
prompt_task = asyncio.create_task(self._run_g2pw_stage(prompt_segments))
return await asyncio.gather(prompt_task, target_task)
target_profiled = await target_task
submit_at = time.perf_counter()
prompt_profiled = ProfiledResult(
result=prompt_segments,
submit_at=float(submit_at),
started_at=float(submit_at),
finished_at=float(submit_at),
profile={},
)
return prompt_profiled, target_profiled
@staticmethod
def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
return float(
@ -310,12 +381,32 @@ class PrepareCoordinator:
target_segments,
prompt_cpu_run_ms: float,
target_cpu_run_ms: float,
prompt_base_profile: Dict[str, float] | None = None,
target_base_profile: Dict[str, float] | None = None,
) -> tuple[ProfiledResult, ProfiledResult]:
prompt_is_empty = len(prompt_segments or []) == 0
if self.text_feature_executor is not None:
target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms))
target_feature_task = asyncio.create_task(
self._run_on_executor(
self.text_feature_executor,
self._build_text_features,
target_segments,
None,
target_cpu_run_ms,
target_base_profile,
)
)
if not prompt_is_empty:
prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms))
prompt_feature_task = asyncio.create_task(
self._run_on_executor(
self.text_feature_executor,
self._build_text_features,
prompt_segments,
None,
prompt_cpu_run_ms,
prompt_base_profile,
)
)
return await asyncio.gather(prompt_feature_task, target_feature_task)
target_profiled = await target_feature_task
submit_at = time.perf_counter()
@ -328,7 +419,8 @@ class PrepareCoordinator:
return prompt_profiled, target_profiled
await self.text_feature_gate.acquire()
target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
target_profile: Dict[str, float] = dict(target_base_profile or {})
target_profile["cpu_preprocess_ms"] = float(target_cpu_run_ms)
submit_at = time.perf_counter()
started_at = float(submit_at)
try:
@ -377,7 +469,8 @@ class PrepareCoordinator:
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
return prompt_profiled, target_profiled
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
prompt_profile: Dict[str, float] = dict(prompt_base_profile or {})
prompt_profile["cpu_preprocess_ms"] = float(prompt_cpu_run_ms)
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
prompt_segments,
target_segments,
@ -589,20 +682,31 @@ class PrepareCoordinator:
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
try:
text_pair_start = time.perf_counter()
ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
text_feature_pair_task = asyncio.create_task(
self._run_text_feature_pair_stage(
g2pw_pair_start = time.perf_counter()
g2pw_pair_task = asyncio.create_task(
self._run_g2pw_pair_stage(
cpu_stage.prompt_cpu_profiled.result,
cpu_stage.target_cpu_profiled.result,
cpu_stage.prompt_cpu_profiled.run_ms,
cpu_stage.target_cpu_profiled.run_ms,
)
)
(prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather(
text_feature_pair_task,
ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
(prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather(
g2pw_pair_task,
ref_audio_task,
)
g2pw_pair_end = time.perf_counter()
text_pair_start = time.perf_counter()
text_feature_pair_task = asyncio.create_task(
self._run_text_feature_pair_stage(
prompt_g2pw_profiled.result,
target_g2pw_profiled.result,
cpu_stage.prompt_cpu_profiled.run_ms,
cpu_stage.target_cpu_profiled.run_ms,
prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
target_base_profile=dict(target_g2pw_profiled.profile or {}),
)
)
prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task
text_pair_end = time.perf_counter()
state = build_request_state_from_parts(
tts=self.tts,
@ -619,6 +723,17 @@ class PrepareCoordinator:
"prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
"executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
"text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
"g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0),
"prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
"prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
"prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
"text_g2pw_run_ms": target_g2pw_profiled.run_ms,
"text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
"text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
"text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": 0.0,

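The coordinator now runs G2PW resolution concurrently with reference-audio loading, then feeds the resolved segments and their stage profiles into the text-feature pair stage. A condensed sketch of the new ordering in the request path (error handling elided):

async def _prepare(self, cpu_stage):
    # G2PW resolution and reference-audio decoding overlap
    g2pw_task = asyncio.create_task(self._run_g2pw_pair_stage(
        cpu_stage.prompt_cpu_profiled.result,
        cpu_stage.target_cpu_profiled.result,
    ))
    ref_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
    (prompt_g2pw, target_g2pw), ref_audio = await asyncio.gather(g2pw_task, ref_task)

    # the BERT feature stage starts only after G2PW, seeded with the G2PW stage profiles
    prompt_feat, target_feat = await self._run_text_feature_pair_stage(
        prompt_g2pw.result, target_g2pw.result,
        cpu_stage.prompt_cpu_profiled.run_ms, cpu_stage.target_cpu_profiled.run_ms,
        prompt_base_profile=dict(prompt_g2pw.profile or {}),
        target_base_profile=dict(target_g2pw.profile or {}),
    )
    return prompt_feat, target_feat, ref_audio
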
View File

@ -241,6 +241,24 @@ def build_request_state_from_parts(
prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0)
),
"prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)),
"prompt_text_g2pw_total_ms": float(prompt_result.profile.get("g2pw_total_ms", 0.0)),
"prompt_text_g2pw_prepare_ms": float(prompt_result.profile.get("g2pw_prepare_ms", 0.0)),
"prompt_text_g2pw_predict_ms": float(prompt_result.profile.get("g2pw_predict_ms", 0.0)),
"prompt_text_g2pw_post_ms": float(prompt_result.profile.get("g2pw_post_ms", 0.0)),
"prompt_text_g2pw_runtime_total_ms": float(prompt_result.profile.get("g2pw_runtime_total_ms", 0.0)),
"prompt_text_g2pw_runtime_queue_wait_ms": float(
prompt_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)
),
"prompt_text_g2pw_runtime_collect_wait_ms": float(
prompt_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)
),
"prompt_text_g2pw_runtime_run_ms": float(prompt_result.profile.get("g2pw_runtime_run_ms", 0.0)),
"prompt_text_g2pw_runtime_batch_rows_peak": float(
prompt_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)
),
"prompt_text_g2pw_runtime_batch_requests_peak": float(
prompt_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0)
),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": float(prompt_result.total_ms),
@ -267,6 +285,18 @@ def build_request_state_from_parts(
),
"text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)),
"text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)),
"text_g2pw_total_ms": float(target_result.profile.get("g2pw_total_ms", 0.0)),
"text_g2pw_prepare_ms": float(target_result.profile.get("g2pw_prepare_ms", 0.0)),
"text_g2pw_predict_ms": float(target_result.profile.get("g2pw_predict_ms", 0.0)),
"text_g2pw_post_ms": float(target_result.profile.get("g2pw_post_ms", 0.0)),
"text_g2pw_runtime_total_ms": float(target_result.profile.get("g2pw_runtime_total_ms", 0.0)),
"text_g2pw_runtime_queue_wait_ms": float(target_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)),
"text_g2pw_runtime_collect_wait_ms": float(target_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)),
"text_g2pw_runtime_run_ms": float(target_result.profile.get("g2pw_runtime_run_ms", 0.0)),
"text_g2pw_runtime_batch_rows_peak": float(target_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)),
"text_g2pw_runtime_batch_requests_peak": float(
target_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0)
),
"text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)),
"text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
"audio_load_ms": audio_load_ms,

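Because g2pw_total_ms is computed in chinese2.g2p_segments as exactly prepare + predict + post, the per-request keys should decompose the same way. A hedged sanity check over a request-state profile dict p:

def check_g2pw_profile(p: dict) -> None:
    total = p.get("text_g2pw_total_ms", 0.0)
    parts = (p.get("text_g2pw_prepare_ms", 0.0)
             + p.get("text_g2pw_predict_ms", 0.0)
             + p.get("text_g2pw_post_ms", 0.0))
    # allow a small epsilon for float accumulation order
    assert abs(total - parts) < 1e-3, (total, parts)
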
View File

@ -8,6 +8,7 @@ sys.path.append(now_dir)
from text.LangSegmenter import LangSegmenter
from text import cleaned_text_to_sequence
from text import chinese2
from text.cleaner import clean_text
@ -83,16 +84,27 @@ def preprocess_text_segments_payload(
payloads: List[PreparedTextSegmentPayload] = []
total_phones_len = 0
for segment_text, segment_lang in zip(textlist, langlist):
phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
normalized_language = segment_lang.replace("all_", "")
if normalized_language == "zh":
norm_text = chinese2.text_normalize(segment_text)
phones = []
word2ph = None
needs_g2pw = True
estimated_phones_len = max(0, len(norm_text) * 2)
else:
phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
needs_g2pw = False
estimated_phones_len = len(phones)
payloads.append(
{
"language": segment_lang.replace("all_", ""),
"language": normalized_language,
"phones": phones,
"word2ph": word2ph,
"norm_text": norm_text,
"needs_g2pw": needs_g2pw,
}
)
total_phones_len += len(phones)
total_phones_len += int(estimated_phones_len)
if not final and total_phones_len < 6:
return preprocess_text_segments_payload("." + text, language, version, final=True)
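
For a Chinese segment the payload now defers phoneme generation entirely; only the normalized text travels forward, and a two-phones-per-character heuristic stands in for the real phone count in the minimum-length check. A sketch of the deferred payload shape (values illustrative):

payload = {
    "language": "zh",
    "phones": [],           # filled in later by resolve_g2pw_segments
    "word2ph": None,
    "norm_text": "你好世界",
    "needs_g2pw": True,
}
estimated_phones_len = max(0, len(payload["norm_text"]) * 2)  # 8, counted toward the < 6 check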

View File

@ -1,5 +1,6 @@
import os
import re
import time
import cn2an
from pypinyin import lazy_pinyin, Style
@ -77,6 +78,205 @@ def g2p(text):
return phones, word2ph
def _prepare_g2p_segments(segments):
prepared_segments = []
batch_inputs = []
for segment in segments:
processed_segment = re.sub("[a-zA-Z]+", "", segment)
seg_cut = psg.lcut(processed_segment)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
prepared_segments.append(
{
"segment": processed_segment,
"seg_cut": seg_cut,
}
)
if processed_segment:
batch_inputs.append(processed_segment)
return prepared_segments, batch_inputs
def _build_segment_from_g2pw(segment: str, seg_cut, pinyins):
phones_list = []
word2ph = []
initials = []
finals = []
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == "eng":
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
word_pinyins = correct_pronunciation(word, word_pinyins)
for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
for c, v in zip(initials, finals):
raw_pinyin = c + v
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
return phones_list, word2ph
def _build_segment_without_g2pw(segment: str, seg_cut):
initials = []
finals = []
for word, pos in seg_cut:
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
phones_list = []
word2ph = []
for c, v in zip(sum(initials, []), sum(finals, [])):
raw_pinyin = c + v
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"}
if v_without_tone in v_rep_map:
pinyin = c + v_rep_map[v_without_tone]
else:
pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"}
if pinyin in pinyin_rep_map:
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"}
if pinyin[0] in single_rep_map:
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
return phones_list, word2ph
def g2p_segments(segments, return_profile: bool = False):
prepare_start = time.perf_counter()
prepared_segments, batch_inputs = _prepare_g2p_segments(segments)
profile = {
"g2pw_prepare_ms": 0.0,
"g2pw_predict_ms": 0.0,
"g2pw_post_ms": 0.0,
"g2pw_runtime_total_ms": 0.0,
"g2pw_runtime_queue_wait_ms": 0.0,
"g2pw_runtime_collect_wait_ms": 0.0,
"g2pw_runtime_run_ms": 0.0,
"g2pw_runtime_batch_rows": 0.0,
"g2pw_runtime_batch_requests": 0.0,
"g2pw_runtime_pool_workers": 0.0,
"g2pw_runtime_shard_index": 0.0,
}
profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0)
if is_g2pw and batch_inputs:
converter = g2pw._g2pw
if hasattr(converter, "predict_sentences_with_profile"):
g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs)
for key, value in dict(predict_profile or {}).items():
profile[key] = float(value)
else:
predict_start = time.perf_counter()
g2pw_batch_results = converter(batch_inputs)
profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0)
else:
g2pw_batch_results = []
post_start = time.perf_counter()
results = []
batch_cursor = 0
for item in prepared_segments:
segment = item["segment"]
if not segment:
results.append(([], [], segment))
continue
if not is_g2pw:
phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"])
results.append((phones, word2ph, segment))
continue
pinyins = g2pw_batch_results[batch_cursor]
batch_cursor += 1
phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins)
results.append((phones, word2ph, segment))
profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0)
profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"])
if return_profile:
return results, profile
return results
def _get_initials_finals(word):
initials = []
finals = []
@ -180,125 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
def _g2p(segments):
phones_list = []
word2ph = []
g2pw_batch_results = []
g2pw_batch_cursor = 0
processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
if is_g2pw:
batch_inputs = [seg for seg in processed_segments if seg]
g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []
for seg in processed_segments:
pinyins = []
seg_cut = psg.lcut(seg)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
initials = []
finals = []
if not is_g2pw:
for word, pos in seg_cut:
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
# erhua (rhotacized ending)
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
print("pypinyin结果", initials, finals)
else:
# g2pw runs batched whole-sentence inference; results are taken per sentence
if seg:
pinyins = g2pw_batch_results[g2pw_batch_cursor]
g2pw_batch_cursor += 1
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == "eng":
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
# polyphone disambiguation
word_pinyins = correct_pronunciation(word, word_pinyins)
for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
# erhua (rhotacized ending)
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
# print("g2pw结果",initials,finals)
for c, v in zip(initials, finals):
raw_pinyin = c + v
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
# multi-syllable case
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
# single-syllable case
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
for phones, item_word2ph, _segment in g2p_segments(segments):
phones_list += phones
word2ph += item_word2ph
return phones_list, word2ph
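
g2p_segments preserves the per-segment output of the old _g2p loop, (phones, word2ph, norm_text) triples, while batching all G2PW inference up front. A usage sketch, assuming a loaded G2PW backend:

from text import chinese2

results, profile = chinese2.g2p_segments(["你好世界", "今天天气不错"], return_profile=True)
for phones, word2ph, norm_text in results:
    print(norm_text, len(phones), word2ph)
print(profile["g2pw_prepare_ms"], profile["g2pw_predict_ms"], profile["g2pw_post_ms"])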

View File

@ -0,0 +1,670 @@
import ctypes
import fcntl
import os
import subprocess
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Deque, Dict, List, Tuple
import numpy as np
from .onnx_api import _G2PWBaseOnnxConverter
class G2PWCudaError(RuntimeError):
pass
@dataclass
class G2PWBatchTask:
model_input: Dict[str, np.ndarray]
created_at: float = field(default_factory=time.perf_counter)
enqueued_at: float = 0.0
done_event: threading.Event = field(default_factory=threading.Event)
output: np.ndarray | None = None
profile: Dict[str, float] = field(default_factory=dict)
error: Exception | None = None
_ROOT_DIR = Path(__file__).resolve().parents[3]
_PACKAGE_DIR = Path(__file__).resolve().parent
_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge"
_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp"
_LOCK_PATH = _OUTPUT_DIR / "build.lock"
def _env_flag(name: str, default: bool) -> int:
raw = os.environ.get(name)
if raw is None:
return 1 if default else 0
return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1
def _env_int(name: str, default: int) -> int:
raw = os.environ.get(name)
if raw is None or raw.strip() == "":
return int(default)
return int(raw)
def _resolve_cuda_root() -> Path:
env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip()
candidates = [
env_root,
_ROOT_DIR / "third_party" / "g2pw-cu",
]
for candidate in candidates:
if not candidate:
continue
path = Path(candidate).expanduser().resolve()
if path.exists():
return path
checked = [
str(Path(candidate).expanduser().resolve())
for candidate in candidates
if str(candidate).strip() != ""
]
raise G2PWCudaError(
"Cannot locate g2pw-cu root. "
"Expected one of: "
f"{checked}. "
"Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into "
f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} "
"or set GPTSOVITS_G2PW_CUDA_ROOT explicitly."
)
def _resolve_runtime_paths() -> tuple[Path, Path, Path]:
cuda_root = _resolve_cuda_root()
runtime_lib = Path(
os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so"))
).expanduser()
manifest_path = Path(
os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt"))
).expanduser()
weights_path = Path(
os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin"))
).expanduser()
for path in (runtime_lib, manifest_path, weights_path):
if not path.exists():
raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}")
return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve()
def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None:
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
compile_cmd = [
os.environ.get("CXX", "g++"),
"-O3",
"-std=c++17",
"-shared",
"-fPIC",
str(_WRAPPER_SOURCE),
"-I",
str(runtime_lib.parent.parent / "include"),
"-L",
str(runtime_lib.parent),
"-lg2pw_runtime",
f"-Wl,-rpath,{runtime_lib.parent}",
"-o",
str(wrapper_output),
]
result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False)
if result.returncode != 0:
raise G2PWCudaError(
"Failed to build g2pw-cu bridge:\n"
f"cmd={' '.join(compile_cmd)}\n"
f"stdout={result.stdout}\n"
f"stderr={result.stderr}"
)
def _ensure_bridge_built(runtime_lib: Path) -> Path:
wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so"
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
with _LOCK_PATH.open("w", encoding="utf-8") as lock_file:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
needs_build = not wrapper_output.exists()
if not needs_build:
so_mtime = wrapper_output.stat().st_mtime
needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime
if needs_build:
tmp_output = wrapper_output.with_suffix(".tmp.so")
if tmp_output.exists():
tmp_output.unlink()
_build_bridge(tmp_output, runtime_lib)
tmp_output.replace(wrapper_output)
return wrapper_output
def _load_bridge():
runtime_lib, manifest_path, weights_path = _resolve_runtime_paths()
bridge_path = _ensure_bridge_built(runtime_lib)
global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0))
ctypes.CDLL(str(runtime_lib), mode=global_mode)
lib = ctypes.CDLL(str(bridge_path))
lib.g2pw_runtime_create.argtypes = [
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
lib.g2pw_runtime_create.restype = ctypes.c_void_p
lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_destroy.restype = None
lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_num_labels.restype = ctypes.c_int
lib.g2pw_runtime_run.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int32,
ctypes.c_int32,
ctypes.c_void_p,
]
lib.g2pw_runtime_run.restype = ctypes.c_int
return lib, manifest_path, weights_path, runtime_lib
def _gemm_precision_value() -> int:
precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower()
if precision == "fp16":
return 1
if precision == "bf16":
return 2
return 0
class G2PWRuntimeWrapper:
def __init__(self, shard_index: int = 0) -> None:
self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge()
self.shard_index = int(shard_index)
self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0)
self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False)
self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False)
self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False)
self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True)
self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False)
self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0)
self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0)
self.gemm_precision = _gemm_precision_value()
self.lock = threading.Lock()
self.handle = None
self.max_batch_size = 0
self.max_seq_len = 0
self.num_labels = 0
self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0
self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0)
self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64))
self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96))
self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096))
self.batch_condition = threading.Condition()
self.pending_tasks: Deque[G2PWBatchTask] = deque()
self.batch_total_tasks = 0
self.batch_total_batches = 0
self.batch_total_rows = 0
self.batch_total_queue_wait_ms = 0.0
self.batch_queue_wait_peak_ms = 0.0
self.batch_total_collect_wait_ms = 0.0
self.batch_collect_wait_peak_ms = 0.0
self.batch_total_run_ms = 0.0
self.batch_run_peak_ms = 0.0
self.batch_rows_peak = 0
self.batch_requests_peak = 0
self.batch_pending_peak = 0
self.closed = False
self._ensure_capacity(
batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)),
seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)),
)
self.batch_worker = None
if self.batch_enabled:
self.batch_worker = threading.Thread(
target=self._batch_loop,
name=f"g2pw-cuda-batch-worker-{self.shard_index}",
daemon=True,
)
self.batch_worker.start()
def _destroy_handle(self) -> None:
if self.handle:
self.lib.g2pw_runtime_destroy(self.handle)
self.handle = None
def close(self) -> None:
with self.batch_condition:
self.closed = True
self.batch_condition.notify_all()
self._destroy_handle()
def __del__(self):
try:
self.close()
except Exception:
pass
def _last_error(self) -> str:
if not self.handle:
return "uninitialized runtime"
message = self.lib.g2pw_runtime_last_error(self.handle)
return "" if not message else message.decode("utf-8", errors="replace")
def _create_handle(self, batch_size: int, seq_len: int) -> None:
new_handle = self.lib.g2pw_runtime_create(
str(self.manifest_path).encode("utf-8"),
str(self.weights_path).encode("utf-8"),
int(self.device_ordinal),
int(batch_size),
int(seq_len),
int(self.full_graph_cache_limit),
int(self.tail_graph_cache_limit),
int(self.allow_tensor_cores),
int(self.use_cublaslt_bias_epilogue),
int(self.enable_profiling),
int(self.enable_cuda_graph),
int(self.dump_graph_cache_stats),
int(self.gemm_precision),
)
if not new_handle:
raise G2PWCudaError("g2pw-cu returned null runtime handle")
self.handle = new_handle
self.max_batch_size = int(batch_size)
self.max_seq_len = int(seq_len)
self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle))
last_error = self._last_error()
if self.num_labels <= 0 or last_error:
self.close()
raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}")
def _ensure_capacity(self, batch_size: int, seq_len: int) -> None:
target_batch = max(1, int(batch_size))
target_seq = max(1, int(seq_len))
if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len:
return
next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0)
next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0)
self._destroy_handle()
self._create_handle(batch_size=next_batch, seq_len=next_seq)
@staticmethod
def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64)
token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64)
attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64)
phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32)
char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64)
position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64)
batch_size = int(char_ids.shape[0])
if input_ids.shape[0] == 1 and batch_size > 1:
input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64)
token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64)
attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64)
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_masks": attention_masks,
"phoneme_masks": phoneme_masks,
"char_ids": char_ids,
"position_ids": position_ids,
}
def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
normalized = self._normalize_model_input(model_input)
input_ids = normalized["input_ids"]
token_type_ids = normalized["token_type_ids"]
attention_masks = normalized["attention_masks"]
phoneme_masks = normalized["phoneme_masks"]
char_ids = normalized["char_ids"]
position_ids = normalized["position_ids"]
batch_size = int(char_ids.shape[0])
seq_len = int(input_ids.shape[1])
probs = np.empty((batch_size, self.num_labels), dtype=np.float32)
with self.lock:
self._ensure_capacity(batch_size=batch_size, seq_len=seq_len)
status = self.lib.g2pw_runtime_run(
self.handle,
input_ids.ctypes.data_as(ctypes.c_void_p),
token_type_ids.ctypes.data_as(ctypes.c_void_p),
attention_masks.ctypes.data_as(ctypes.c_void_p),
phoneme_masks.ctypes.data_as(ctypes.c_void_p),
char_ids.ctypes.data_as(ctypes.c_void_p),
position_ids.ctypes.data_as(ctypes.c_void_p),
batch_size,
seq_len,
probs.ctypes.data_as(ctypes.c_void_p),
)
if int(status) != 0:
raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}")
return probs
def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool:
request_count = len(tasks) + 1
if request_count > self.batch_max_requests:
return False
total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int(
candidate.model_input["char_ids"].shape[0]
)
if total_rows > self.batch_max_rows:
return False
total_tokens = sum(
int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks
) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1])
return total_tokens <= self.batch_max_tokens
def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]:
normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks]
total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs)
max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs)
input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64)
phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32)
char_ids = np.zeros((total_rows,), dtype=np.int64)
position_ids = np.zeros((total_rows,), dtype=np.int64)
slices: List[Tuple[int, int]] = []
cursor = 0
for item in normalized_inputs:
rows = int(item["char_ids"].shape[0])
seq_len = int(item["input_ids"].shape[1])
next_cursor = cursor + rows
input_ids[cursor:next_cursor, :seq_len] = item["input_ids"]
token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"]
attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"]
phoneme_masks[cursor:next_cursor] = item["phoneme_masks"]
char_ids[cursor:next_cursor] = item["char_ids"]
position_ids[cursor:next_cursor] = item["position_ids"]
slices.append((cursor, next_cursor))
cursor = next_cursor
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_masks": attention_masks,
"phoneme_masks": phoneme_masks,
"char_ids": char_ids,
"position_ids": position_ids,
}, slices
def _finish_task(
self,
task: G2PWBatchTask,
output: np.ndarray | None = None,
profile: Dict[str, float] | None = None,
error: Exception | None = None,
) -> None:
task.output = output
task.profile = dict(profile or {})
task.error = error
task.done_event.set()
def _batch_loop(self) -> None:
while True:
with self.batch_condition:
while not self.pending_tasks and not self.closed:
self.batch_condition.wait()
if self.closed and not self.pending_tasks:
return
first_task = self.pending_tasks.popleft()
batch_tasks = [first_task]
collect_started = time.perf_counter()
deadline = collect_started + self.batch_window_s
while True:
if len(batch_tasks) >= self.batch_max_requests:
break
remaining = deadline - time.perf_counter()
if remaining <= 0.0:
break
if not self.pending_tasks:
self.batch_condition.wait(timeout=remaining)
continue
candidate = self.pending_tasks[0]
if not self._can_append_task(batch_tasks, candidate):
break
batch_tasks.append(self.pending_tasks.popleft())
collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0)
now = time.perf_counter()
queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks]
try:
merged_input, row_slices = self._merge_batch_inputs(batch_tasks)
run_started = time.perf_counter()
merged_output = self._run_direct(merged_input)
run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0)
for task, (start, end) in zip(batch_tasks, row_slices):
task_rows = int(task.model_input["char_ids"].shape[0])
task_seq_len = int(task.model_input["input_ids"].shape[1])
self._finish_task(
task,
output=np.ascontiguousarray(merged_output[start:end]),
profile={
"g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)),
"g2pw_runtime_collect_wait_ms": float(collect_wait_ms),
"g2pw_runtime_run_ms": float(run_ms),
"g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)),
"g2pw_runtime_batch_requests": float(len(batch_tasks)),
"g2pw_runtime_task_rows": float(task_rows),
"g2pw_runtime_task_seq_len": float(task_seq_len),
"g2pw_runtime_shard_index": float(self.shard_index),
},
)
except Exception as exc:
run_ms = 0.0
for task in batch_tasks:
self._finish_task(task, error=exc)
finally:
with self.batch_condition:
self.batch_total_batches += 1
self.batch_total_tasks += len(batch_tasks)
self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
self.batch_total_queue_wait_ms += float(sum(queue_wait_values))
self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0]))
self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks))
self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms))
self.batch_total_run_ms += float(run_ms)
self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms))
self.batch_rows_peak = max(
self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
)
self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks))
def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
task = G2PWBatchTask(model_input=model_input)
with self.batch_condition:
if self.closed:
raise G2PWCudaError("g2pw-cu batch worker already closed")
task.enqueued_at = time.perf_counter()
self.pending_tasks.append(task)
self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks))
self.batch_condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.output is not None
return task.output, dict(task.profile)
def snapshot(self) -> Dict[str, float | int | bool]:
with self.batch_condition:
average_tasks_per_batch = (
float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
)
average_rows_per_batch = (
float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
)
average_queue_wait_ms = (
float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0
)
average_collect_wait_ms = (
float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks)
if self.batch_total_tasks > 0
else 0.0
)
return {
"shard_index": int(self.shard_index),
"enabled": bool(self.batch_enabled),
"window_ms": float(self.batch_window_s * 1000.0),
"max_requests": int(self.batch_max_requests),
"max_rows": int(self.batch_max_rows),
"max_tokens": int(self.batch_max_tokens),
"pending": int(len(self.pending_tasks)),
"pending_peak": int(self.batch_pending_peak),
"total_batches": int(self.batch_total_batches),
"total_tasks": int(self.batch_total_tasks),
"total_rows": int(self.batch_total_rows),
"avg_tasks_per_batch": float(average_tasks_per_batch),
"avg_rows_per_batch": float(average_rows_per_batch),
"avg_queue_wait_ms": float(average_queue_wait_ms),
"queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms),
"avg_collect_wait_ms": float(average_collect_wait_ms),
"collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms),
"run_total_ms": float(self.batch_total_run_ms),
"run_peak_ms": float(self.batch_run_peak_ms),
"batch_rows_peak": int(self.batch_rows_peak),
"batch_requests_peak": int(self.batch_requests_peak),
}
def pending_rows(self) -> int:
with self.batch_condition:
return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks))
def pending_count(self) -> int:
with self.batch_condition:
return int(len(self.pending_tasks))
def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
if not self.batch_enabled:
started = time.perf_counter()
output = self._run_direct(model_input)
return output, {
"g2pw_runtime_queue_wait_ms": 0.0,
"g2pw_runtime_collect_wait_ms": 0.0,
"g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0),
"g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]),
"g2pw_runtime_batch_requests": 1.0,
"g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]),
"g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]),
"g2pw_runtime_shard_index": float(self.shard_index),
}
return self._submit_batched(model_input)
def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
output, _profile = self.run_with_profile(model_input)
return output
class G2PWRuntimePool:
def __init__(self) -> None:
self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2))
self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)]
self.lock = threading.Lock()
def _pick_shard(self) -> G2PWRuntimeWrapper:
with self.lock:
return min(
self.shards,
key=lambda shard: (
shard.pending_rows(),
shard.pending_count(),
shard.snapshot().get("avg_queue_wait_ms", 0.0),
),
)
def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
shard = self._pick_shard()
output, profile = shard.run_with_profile(model_input)
profile["g2pw_runtime_pool_workers"] = float(self.worker_count)
return output, profile
def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
output, _profile = self.run_with_profile(model_input)
return output
def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]:
shard_snapshots = [dict(shard.snapshot()) for shard in self.shards]
avg_queue_wait_ms = 0.0
total_tasks = 0.0
pending = 0
pending_peak = 0
total_batches = 0
total_rows = 0
batch_rows_peak = 0
batch_requests_peak = 0
for snapshot in shard_snapshots:
tasks = float(snapshot.get("total_tasks", 0.0))
avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks
total_tasks += tasks
pending += int(snapshot.get("pending", 0))
pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0)))
total_batches += int(snapshot.get("total_batches", 0))
total_rows += int(snapshot.get("total_rows", 0))
batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0)))
batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0)))
return {
"worker_count": int(self.worker_count),
"pending": int(pending),
"pending_peak": int(pending_peak),
"total_batches": int(total_batches),
"total_tasks": int(total_tasks),
"total_rows": int(total_rows),
"avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0,
"batch_rows_peak": int(batch_rows_peak),
"batch_requests_peak": int(batch_requests_peak),
"shards": shard_snapshots,
}
class G2PWCudaConverter(_G2PWBaseOnnxConverter):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
super().__init__(
model_dir=model_dir,
style=style,
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
self.runtime = G2PWRuntimePool()
self.backend = "cuda"
primary_runtime = self.runtime.shards[0]
self.device = f"cuda:{primary_runtime.device_ordinal}"
self.checkpoint_path = str(primary_runtime.weights_path)
self.providers = ["g2pw-cu"]
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
probs = self.runtime.run(model_input)
preds = np.argmax(probs, axis=1).tolist()
confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
return [self.labels[pred] for pred in preds], confidences
def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]:
started = time.perf_counter()
probs, runtime_profile = self.runtime.run_with_profile(model_input)
preds = np.argmax(probs, axis=1).tolist()
confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
profile = dict(runtime_profile)
profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0)
profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"])
return [self.labels[pred] for pred in preds], confidences, profile
def snapshot(self) -> Dict[str, float | int | bool]:
return dict(self.runtime.snapshot())
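
End to end, the pool is the only object callers need to touch; shard selection, windowed batching, and capacity growth all happen behind run_with_profile. A hedged usage sketch (model_input construction is the converter's job and is elided here; it is a dict of numpy arrays matching _normalize_model_input):

pool = G2PWRuntimePool()  # spawns GPTSOVITS_G2PW_CUDA_WORKERS shards
probs, prof = pool.run_with_profile(model_input)
print(prof["g2pw_runtime_run_ms"], prof["g2pw_runtime_batch_requests"])

stats = pool.snapshot()   # aggregated across shards
print(stats["avg_queue_wait_ms"], stats["batch_rows_peak"])

Batching is tuned with GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS, _BATCH_MAX_REQUESTS, _BATCH_MAX_ROWS, and _BATCH_MAX_TOKENS; setting GPTSOVITS_G2PW_CUDA_BATCHING=0 bypasses the batch worker and runs each request directly.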

View File

@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style
from pypinyin.seg.simpleseg import simple_seg
from pypinyin.converter import UltimateConverter
from pypinyin.contrib.tone_convert import to_tone
from .cuda_api import G2PWCudaConverter
from .onnx_api import G2PWOnnxConverter
current_file_path = os.path.dirname(__file__)
@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin):
tone_sandhi=False,
**kwargs,
):
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower()
last_error = None
self._g2pw = None
if backend in {"cuda", "auto"}:
try:
self._g2pw = G2PWCudaConverter(
model_dir=model_dir,
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
except Exception as exc:
last_error = exc
strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in {
"1",
"true",
"yes",
"on",
}
if backend == "cuda" and strict_mode:
raise
if self._g2pw is None:
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
if last_error is not None:
print(f"[g2pw] cuda backend unavailable, fallback to onnx: {last_error}")
self._converter = Converter(
self._g2pw,
v_to_u=v_to_u,

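Backend selection is driven by GPTSOVITS_G2PW_BACKEND, and a CUDA failure is fatal only in strict mode. A sketch of the three behaviors:

import os

os.environ["GPTSOVITS_G2PW_BACKEND"] = "onnx"  # any value other than cuda/auto: skip CUDA entirely
os.environ["GPTSOVITS_G2PW_BACKEND"] = "auto"  # try CUDA, fall back to ONNX with a logged warning
os.environ["GPTSOVITS_G2PW_BACKEND"] = "cuda"  # try CUDA; with strict mode set, raise instead of falling back
os.environ["GPTSOVITS_G2PW_CUDA_STRICT"] = "1"
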
View File

@ -0,0 +1,183 @@
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include "g2pw/runtime.h"
namespace {
struct G2PWRuntimeHandle {
std::unique_ptr<g2pw::Runtime> runtime;
std::string last_error;
int num_labels = 0;
};
void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) {
if (handle == nullptr) {
return;
}
handle->last_error = status.message;
}
g2pw::RuntimeConfig BuildConfig(
int device_ordinal,
int max_batch_size,
int max_seq_len,
int full_graph_cache_limit,
int tail_graph_cache_limit,
int allow_tensor_cores,
int use_cublaslt_bias_epilogue,
int enable_profiling,
int enable_cuda_graph,
int dump_graph_cache_stats,
int gemm_precision) {
g2pw::RuntimeConfig config{};
config.device_ordinal = device_ordinal;
config.max_batch_size = max_batch_size;
config.max_seq_len = max_seq_len;
config.full_graph_cache_limit = full_graph_cache_limit;
config.tail_graph_cache_limit = tail_graph_cache_limit;
config.allow_tensor_cores = allow_tensor_cores != 0;
config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0;
config.enable_profiling = enable_profiling != 0;
config.enable_cuda_graph = enable_cuda_graph != 0;
config.dump_graph_cache_stats = dump_graph_cache_stats != 0;
switch (gemm_precision) {
case 1:
config.gemm_precision = g2pw::GemmPrecision::kFp16;
break;
case 2:
config.gemm_precision = g2pw::GemmPrecision::kBf16;
break;
default:
config.gemm_precision = g2pw::GemmPrecision::kFp32;
break;
}
return config;
}
} // namespace
extern "C" {
void* g2pw_runtime_create(
const char* manifest_path,
const char* binary_path,
int device_ordinal,
int max_batch_size,
int max_seq_len,
int full_graph_cache_limit,
int tail_graph_cache_limit,
int allow_tensor_cores,
int use_cublaslt_bias_epilogue,
int enable_profiling,
int enable_cuda_graph,
int dump_graph_cache_stats,
int gemm_precision) {
auto* handle = new G2PWRuntimeHandle();
try {
if (manifest_path == nullptr || binary_path == nullptr) {
handle->last_error = "manifest_path and binary_path must be non-null";
return handle;
}
g2pw::RuntimeConfig config = BuildConfig(
device_ordinal,
max_batch_size,
max_seq_len,
full_graph_cache_limit,
tail_graph_cache_limit,
allow_tensor_cores,
use_cublaslt_bias_epilogue,
enable_profiling,
enable_cuda_graph,
dump_graph_cache_stats,
gemm_precision);
g2pw::Status status = g2pw::Runtime::Create(
config,
std::string(manifest_path),
std::string(binary_path),
&handle->runtime);
if (!status.ok()) {
SetError(handle, status);
return handle;
}
handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0;
handle->last_error.clear();
return handle;
} catch (const std::exception& exc) {
handle->last_error = exc.what();
return handle;
} catch (...) {
handle->last_error = "unknown exception";
return handle;
}
}
void g2pw_runtime_destroy(void* raw_handle) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
delete handle;
}
const char* g2pw_runtime_last_error(void* raw_handle) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
if (handle == nullptr) {
return "invalid runtime handle";
}
return handle->last_error.c_str();
}
int g2pw_runtime_num_labels(void* raw_handle) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
if (handle == nullptr || handle->runtime == nullptr) {
return 0;
}
return handle->num_labels;
}
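// Executes one forward pass over a flat batch. On success the class
// probabilities are written into `probs` and kOk is returned (as int);
// any other status code signals failure, with the detail retrievable
// through g2pw_runtime_last_error().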
int g2pw_runtime_run(
void* raw_handle,
const std::int64_t* input_ids,
const std::int64_t* token_type_ids,
const std::int64_t* attention_mask,
const float* phoneme_mask,
const std::int64_t* char_ids,
const std::int64_t* position_ids,
std::int32_t batch_size,
std::int32_t seq_len,
float* probs) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
if (handle == nullptr || handle->runtime == nullptr) {
return static_cast<int>(g2pw::StatusCode::kInvalidArgument);
}
try {
g2pw::InferenceInputs inputs{};
inputs.input_ids = input_ids;
inputs.token_type_ids = token_type_ids;
inputs.attention_mask = attention_mask;
inputs.phoneme_mask = phoneme_mask;
inputs.char_ids = char_ids;
inputs.position_ids = position_ids;
inputs.batch_size = batch_size;
inputs.seq_len = seq_len;
g2pw::InferenceOutputs outputs{};
outputs.probs = probs;
const g2pw::Status status = handle->runtime->Run(inputs, outputs);
if (!status.ok()) {
SetError(handle, status);
return static_cast<int>(status.code);
}
handle->last_error.clear();
return static_cast<int>(g2pw::StatusCode::kOk);
} catch (const std::exception& exc) {
handle->last_error = exc.what();
return static_cast<int>(g2pw::StatusCode::kInternalError);
} catch (...) {
handle->last_error = "unknown exception";
return static_cast<int>(g2pw::StatusCode::kInternalError);
}
}
}
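
Since the wrapper exposes a plain C ABI, it can be exercised directly from Python via ctypes. A minimal sketch, assuming the source above is built into a shared library named libg2pw_runtime.so; the library name, artifact paths, and all concrete argument values here are assumptions for illustration, not part of the patch:

import ctypes

lib = ctypes.CDLL("libg2pw_runtime.so")  # hypothetical build target

lib.g2pw_runtime_create.restype = ctypes.c_void_p
lib.g2pw_runtime_create.argtypes = [ctypes.c_char_p, ctypes.c_char_p] + [ctypes.c_int] * 11
lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_num_labels.restype = ctypes.c_int
lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_destroy.restype = None
lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p]

handle = lib.g2pw_runtime_create(
    b"manifest.json", b"weights.bin",  # illustrative artifact paths
    0,        # device_ordinal
    16,       # max_batch_size
    512,      # max_seq_len
    8, 8,     # full_graph_cache_limit, tail_graph_cache_limit
    1, 1,     # allow_tensor_cores, use_cublaslt_bias_epilogue
    0, 1, 0,  # enable_profiling, enable_cuda_graph, dump_graph_cache_stats
    0,        # gemm_precision: 0 = fp32, 1 = fp16, 2 = bf16
)
error = lib.g2pw_runtime_last_error(handle).decode()
if error:  # create() returns a handle even on failure; the error string says why
    lib.g2pw_runtime_destroy(handle)
    raise RuntimeError(f"g2pw runtime init failed: {error}")
print("num_labels:", lib.g2pw_runtime_num_labels(handle))
lib.g2pw_runtime_destroy(handle)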

View File

@ -3,6 +3,7 @@
import json
import os
import time
import warnings
import zipfile
from typing import Any, Dict, List, Tuple
@ -71,6 +72,23 @@ def _find_first_existing_file(*paths: str) -> str:
raise FileNotFoundError(f"Files not found: {paths}")
def _resolve_tokenizer_source(model_source: str | None) -> str:
candidate_paths = []
if model_source:
candidate_paths.append(model_source)
repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
candidate_paths.extend(
[
os.path.join(repo_root, "pretrained_models", "g2pw-chinese"),
os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"),
]
)
for candidate in candidate_paths:
if candidate and os.path.exists(candidate):
return candidate
return model_source or "bert-base-chinese"
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
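
The resolver added above prefers tokenizer assets that already exist on disk, which is what makes the local_files_only=True load further down safe to use offline. A sketch of the resolution order (the example paths are illustrative):

# Resolution order implemented by _resolve_tokenizer_source:
#   1. the explicit model_source argument, when the path exists
#   2. <repo>/pretrained_models/g2pw-chinese
#   3. <repo>/pretrained_models/chinese-roberta-wwm-ext-large
#   4. otherwise the original model_source string, or "bert-base-chinese"
_resolve_tokenizer_source("/models/my-bert")  # "/models/my-bert" if present on disk
_resolve_tokenizer_source(None)               # first existing local candidate, else "bert-base-chinese"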
@ -106,9 +124,9 @@ class _G2PWBaseOnnxConverter:
self.model_dir = download_and_decompress(model_dir)
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source)
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True)
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
@ -200,6 +218,10 @@ class _G2PWBaseOnnxConverter:
return None
def __call__(self, sentences: List[str]) -> List[List[str]]:
results, _profile = self.predict_sentences_with_profile(sentences)
return results
def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]:
if isinstance(sentences, str):
sentences = [sentences]
@ -213,7 +235,7 @@ class _G2PWBaseOnnxConverter:
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
return partial_results
return partial_results, {}
model_input = prepare_onnx_input(
tokenizer=self.tokenizer,
@ -229,12 +251,21 @@ class _G2PWBaseOnnxConverter:
)
if not model_input:
return partial_results
return partial_results, {}
predict_profile: Dict[str, float] = {}
if self.enable_sentence_dedup:
preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled(
model_input=model_input,
texts=texts,
)
else:
preds, _confidences = self._predict(model_input=model_input)
if hasattr(self, "_predict_with_profile"):
preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input)
else:
predict_started = time.perf_counter()
preds, _confidences = self._predict(model_input=model_input)
predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0)
if self.config.use_char_phoneme:
preds = [pred.split(" ")[1] for pred in preds]
@ -243,7 +274,7 @@ class _G2PWBaseOnnxConverter:
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
return results, predict_profile
def _prepare_data(
self, sentences: List[str]
@ -314,6 +345,52 @@ class _G2PWBaseOnnxConverter:
return preds, confidences
def _predict_with_sentence_dedup_profiled(
self,
model_input: Dict[str, Any],
texts: List[str],
) -> Tuple[List[str], List[float], Dict[str, float]]:
if len(texts) <= 1:
if hasattr(self, "_predict_with_profile"):
return self._predict_with_profile(model_input=model_input)
predict_started = time.perf_counter()
preds, confidences = self._predict(model_input=model_input)
return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
grouped_indices: Dict[str, List[int]] = {}
for idx, text in enumerate(texts):
grouped_indices.setdefault(text, []).append(idx)
if all(len(indices) == 1 for indices in grouped_indices.values()):
if hasattr(self, "_predict_with_profile"):
return self._predict_with_profile(model_input=model_input)
predict_started = time.perf_counter()
preds, confidences = self._predict(model_input=model_input)
return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
preds: List[str] = [""] * len(texts)
confidences: List[float] = [0.0] * len(texts)
merged_profile: Dict[str, float] = {}
for indices in grouped_indices.values():
group_input = {name: value[indices] for name, value in model_input.items()}
if len(indices) > 1:
for name in ("input_ids", "token_type_ids", "attention_masks"):
group_input[name] = group_input[name][:1]
if hasattr(self, "_predict_with_profile"):
group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input)
for key, value in dict(group_profile or {}).items():
merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value)
else:
predict_started = time.perf_counter()
group_preds, group_confidences = self._predict(model_input=group_input)
merged_profile["g2pw_predict_ms"] = float(
merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0
)
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
preds[output_idx] = pred
confidences[output_idx] = confidence
return preds, confidences, merged_profile
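
The grouped path above accumulates per-group timing profiles additively into merged_profile. A self-contained sketch of that accumulation (the key name g2pw_predict_ms matches the patch; the timings are made up):

merged: dict[str, float] = {}
for group_profile in ({"g2pw_predict_ms": 3.0}, {"g2pw_predict_ms": 1.5}, None):
    # dict(group_profile or {}) tolerates a missing profile, as in the loop above
    for key, value in dict(group_profile or {}).items():
        merged[key] = float(merged.get(key, 0.0)) + float(value)
assert merged == {"g2pw_predict_ms": 4.5}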
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
def __init__(