Implement G2PW processing enhancements in TTS framework
Add G2PW segment handling to the TTS pipeline. Text preprocessing now flags Chinese segments that require G2PW instead of resolving them inline, and a new TextPreprocessor.resolve_g2pw_segments resolves all flagged segments in a single batched call. PrepareCoordinator gains a dedicated G2PW worker pool and stage gate, and G2PW timing counters are threaded into the existing profiling. Moving Chinese grapheme-to-phoneme conversion out of the per-segment CPU stage and batching it improves throughput and keeps the Chinese text path easier to maintain (see the call-flow sketch below).
This commit is contained in:
parent 5cf68a91d3
commit 17cb2e5acf
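The diff below touches TTS, TextPreprocessor, PrepareCoordinator, chinese2, and a new CUDA-backed G2PW runtime. As orientation, here is a minimal sketch of the resulting call flow, assuming an already initialized `tts` instance (the method names are taken from the diff; the input text and everything else is illustrative):

    profile: dict = {}

    # Stage 1 (CPU): segment the text; Chinese segments are normalized and
    # flagged needs_g2pw=True rather than converted to phones inline.
    segments = tts.prepare_text_segments("你好,世界。", "auto")

    # Stage 2 (G2PW): resolve every flagged segment in one batched g2pw
    # pass, accumulating g2pw_* timing counters into `profile`.
    segments = tts.resolve_g2pw_segments(segments, profile=profile)

    # Stage 3 (features): build phones and BERT features from the resolved
    # segments; resolve_g2pw_segments is a no-op here if already applied.
    phones, bert_features, norm_text = tts.build_text_features_from_segments(segments, profile=profile)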
@@ -529,6 +529,7 @@ class TTS:
            self.bert_model,
            self.bert_tokenizer,
            self.configs.device,
            version=self.configs.version,
            bert_stage_limiter=self.prepare_bert_stage_limiter,
            bert_batch_worker=self.prepare_bert_batch_worker,
        )
@@ -558,6 +559,16 @@ class TTS:
        return None

    def snapshot_prepare_runtime_components(self) -> dict:
        g2pw_runtime = None
        try:
            from text import chinese2

            g2pw_instance = getattr(chinese2, "g2pw", None)
            g2pw_backend = None if g2pw_instance is None else getattr(g2pw_instance, "_g2pw", None)
            if g2pw_backend is not None and hasattr(g2pw_backend, "snapshot"):
                g2pw_runtime = dict(g2pw_backend.snapshot())
        except Exception:
            g2pw_runtime = None
        return {
            "text_cpu": {
                "workers": int(self.prepare_text_cpu_workers),
@@ -587,6 +598,7 @@ class TTS:
            "text_preprocessor": (
                None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot()
            ),
            "g2pw": g2pw_runtime,
        }

    def _build_text_cpu_admission_state(self) -> dict:
@@ -1204,6 +1216,9 @@ class TTS:
    def prepare_text_segments(self, text: str, language: str):
        return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version)

    def resolve_g2pw_segments(self, prepared_segments, profile: dict | None = None):
        return self.text_preprocessor.resolve_g2pw_segments(prepared_segments, profile=profile)

    def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None):
        return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
@@ -101,6 +101,7 @@ class PreparedTextSegment:
    phones: List[int]
    word2ph: Optional[List[int]]
    norm_text: str
    needs_g2pw: bool = False


class TextPreprocessor:
@@ -109,12 +110,14 @@ class TextPreprocessor:
        bert_model: AutoModelForMaskedLM,
        tokenizer: AutoTokenizer,
        device: torch.device,
        version: str = "v2",
        bert_stage_limiter: StageLimiter | None = None,
        bert_batch_worker: PrepareBertBatchWorker | None = None,
    ):
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device
        self.version = str(version)
        self.bert_stage_limiter = bert_stage_limiter
        self.bert_batch_worker = bert_batch_worker

@@ -261,15 +264,66 @@ class TextPreprocessor:
                phones=list(payload["phones"]),
                word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
                norm_text=str(payload["norm_text"]),
                needs_g2pw=bool(payload.get("needs_g2pw", False)),
            )
            for payload in payloads
        ]

    def resolve_g2pw_segments(
        self,
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None = None,
    ) -> List[PreparedTextSegment]:
        zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)]
        if not zh_indices:
            return prepared_segments
        from text import chinese2

        normalized_segments = [prepared_segments[index].norm_text for index in zh_indices]
        resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True)
        self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0))
        self._accumulate_profile(
            profile,
            "g2pw_runtime_collect_wait_ms",
            g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0),
        )
        self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0))
        self._update_profile_peak(
            profile,
            "g2pw_runtime_batch_rows_peak",
            g2pw_profile.get("g2pw_runtime_batch_rows", 0.0),
        )
        self._update_profile_peak(
            profile,
            "g2pw_runtime_batch_requests_peak",
            g2pw_profile.get("g2pw_runtime_batch_requests", 0.0),
        )
        self._update_profile_peak(
            profile,
            "g2pw_runtime_pool_workers",
            g2pw_profile.get("g2pw_runtime_pool_workers", 0.0),
        )
        for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments):
            prepared_segments[index] = PreparedTextSegment(
                language=prepared_segments[index].language,
                phones=list(cleaned_text_to_sequence(phones, self.version)),
                word2ph=None if word2ph is None else list(word2ph),
                norm_text=str(norm_text),
                needs_g2pw=False,
            )
        return prepared_segments

    def build_phones_and_bert_from_segments(
        self,
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None = None,
    ) -> Tuple[list, torch.Tensor, str]:
        prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
        phones_list: List[List[int]] = []
        bert_list: List[torch.Tensor] = []
        norm_text_list: List[str] = []
@@ -402,6 +456,7 @@ class TextPreprocessor:
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None = None,
    ) -> Tuple[list, torch.Tensor, str]:
        prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
        segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
        pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
        for segment_index, segment in enumerate(prepared_segments):
@@ -473,6 +528,8 @@ class TextPreprocessor:
        prompt_profile: Dict | None = None,
        target_profile: Dict | None = None,
    ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
        prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile)
        target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile)
        prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
        target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
        pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
@@ -120,6 +120,15 @@ class PrepareCoordinator:
            max_workers=self.text_feature_workers,
            thread_name_prefix="prepare-text-feature",
        )
        g2pw_default_workers = max(8, int(getattr(tts, "prepare_text_cpu_workers", 8) or 8))
        self.g2pw_workers = max(
            1,
            int(os.environ.get("GPTSOVITS_PREPARE_G2PW_WORKERS", str(g2pw_default_workers))),
        )
        self.g2pw_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.g2pw_workers,
            thread_name_prefix="prepare-g2pw",
        )
        ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
        self.ref_audio_workers = max(
            1,
@@ -130,12 +139,17 @@ class PrepareCoordinator:
            thread_name_prefix="prepare-ref-audio",
        )
        text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0))
        g2pw_gate_default = max(0, int(self.g2pw_workers))
        text_feature_gate_default = max(0, int(self.text_feature_workers))
        ref_audio_gate_default = max(0, int(self.ref_audio_workers))
        self.text_cpu_gate = AsyncStageGate(
            int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))),
            poll_ms=gate_poll_ms,
        )
        self.g2pw_gate = AsyncStageGate(
            int(os.environ.get("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", str(g2pw_gate_default))),
            poll_ms=gate_poll_ms,
        )
        self.text_feature_gate = AsyncStageGate(
            int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))),
            poll_ms=gate_poll_ms,
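All of the new pool and gate sizes above are read from environment variables with the defaults shown. A hypothetical override, with illustrative values; the variable names are the ones the code reads:

    import os

    # Illustrative only: widen the G2PW resolution pool and cap in-flight
    # G2PW stages; the defaults derive from prepare_text_cpu_workers above.
    os.environ.setdefault("GPTSOVITS_PREPARE_G2PW_WORKERS", "8")
    os.environ.setdefault("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", "8")
    os.environ.setdefault("GPTSOVITS_PREPARE_REF_SLOTS", "4")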
@@ -172,6 +186,7 @@ class PrepareCoordinator:
            "peak_inflight": int(self.peak_inflight),
            "max_inflight": int(self.max_inflight),
            "text_feature_workers": int(self.text_feature_workers),
            "g2pw_workers": int(self.g2pw_workers),
            "ref_audio_workers": int(self.ref_audio_workers),
        }
        runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None)
@@ -182,6 +197,7 @@ class PrepareCoordinator:
            snapshot["prepare_runtime_state"] = None
        snapshot["prepare_stage_gates"] = {
            "text_cpu": self.text_cpu_gate.snapshot(),
            "g2pw": self.g2pw_gate.snapshot(),
            "text_feature": self.text_feature_gate.snapshot(),
            "ref_audio": self.ref_audio_gate.snapshot(),
            "ref_load": self.ref_load_gate.snapshot(),
@@ -204,6 +220,11 @@ class PrepareCoordinator:
    def _prepare_text_cpu(self, text: str, language: str):
        return self.tts.prepare_text_segments(text, language)

    def _resolve_g2pw_segments(self, prepared_segments):
        profile: Dict[str, float] = {}
        resolved_segments = self.tts.resolve_g2pw_segments(prepared_segments, profile=profile)
        return resolved_segments, profile

    def _load_ref_audio_raw(self, ref_audio_path: str):
        return self.tts._load_ref_audio_raw(ref_audio_path)

@@ -225,8 +246,15 @@ class PrepareCoordinator:
            dtype=(dtype if dtype is not None else None) or __import__("torch").float32,
        )

    def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
        profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
    def _build_text_features(
        self,
        prepared_segments,
        language: str,
        cpu_run_ms: float,
        base_profile: Dict[str, float] | None = None,
    ) -> PreparedTextFeatures:
        profile: Dict[str, float] = dict(base_profile or {})
        profile["cpu_preprocess_ms"] = float(cpu_run_ms)
        branch_start = time.perf_counter()
        phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile)
        total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0)
@@ -291,10 +319,53 @@ class PrepareCoordinator:
                prepared_segments,
                language,
                cpu_run_ms,
                None,
            )
        finally:
            self.text_feature_gate.release()

    async def _run_g2pw_stage(self, prepared_segments) -> ProfiledResult:
        has_pending = any(bool(getattr(segment, "needs_g2pw", False)) for segment in (prepared_segments or []))
        if not has_pending:
            submit_at = time.perf_counter()
            return ProfiledResult(
                result=prepared_segments,
                submit_at=float(submit_at),
                started_at=float(submit_at),
                finished_at=float(submit_at),
                profile={},
            )
        await self.g2pw_gate.acquire()
        try:
            profiled = await self._run_on_executor(self.g2pw_executor, self._resolve_g2pw_segments, prepared_segments)
            result, stage_profile = profiled.result
            return ProfiledResult(
                result=result,
                submit_at=float(profiled.submit_at),
                started_at=float(profiled.started_at),
                finished_at=float(profiled.finished_at),
                profile=dict(stage_profile),
            )
        finally:
            self.g2pw_gate.release()

    async def _run_g2pw_pair_stage(self, prompt_segments, target_segments) -> tuple[ProfiledResult, ProfiledResult]:
        prompt_is_empty = len(prompt_segments or []) == 0
        target_task = asyncio.create_task(self._run_g2pw_stage(target_segments))
        if not prompt_is_empty:
            prompt_task = asyncio.create_task(self._run_g2pw_stage(prompt_segments))
            return await asyncio.gather(prompt_task, target_task)
        target_profiled = await target_task
        submit_at = time.perf_counter()
        prompt_profiled = ProfiledResult(
            result=prompt_segments,
            submit_at=float(submit_at),
            started_at=float(submit_at),
            finished_at=float(submit_at),
            profile={},
        )
        return prompt_profiled, target_profiled

    @staticmethod
    def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
        return float(
@@ -310,12 +381,32 @@ class PrepareCoordinator:
        target_segments,
        prompt_cpu_run_ms: float,
        target_cpu_run_ms: float,
        prompt_base_profile: Dict[str, float] | None = None,
        target_base_profile: Dict[str, float] | None = None,
    ) -> tuple[ProfiledResult, ProfiledResult]:
        prompt_is_empty = len(prompt_segments or []) == 0
        if self.text_feature_executor is not None:
            target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms))
            target_feature_task = asyncio.create_task(
                self._run_on_executor(
                    self.text_feature_executor,
                    self._build_text_features,
                    target_segments,
                    None,
                    target_cpu_run_ms,
                    target_base_profile,
                )
            )
            if not prompt_is_empty:
                prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms))
                prompt_feature_task = asyncio.create_task(
                    self._run_on_executor(
                        self.text_feature_executor,
                        self._build_text_features,
                        prompt_segments,
                        None,
                        prompt_cpu_run_ms,
                        prompt_base_profile,
                    )
                )
                return await asyncio.gather(prompt_feature_task, target_feature_task)
            target_profiled = await target_feature_task
            submit_at = time.perf_counter()
@@ -328,7 +419,8 @@ class PrepareCoordinator:
            return prompt_profiled, target_profiled

        await self.text_feature_gate.acquire()
        target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
        target_profile: Dict[str, float] = dict(target_base_profile or {})
        target_profile["cpu_preprocess_ms"] = float(target_cpu_run_ms)
        submit_at = time.perf_counter()
        started_at = float(submit_at)
        try:
@@ -377,7 +469,8 @@ class PrepareCoordinator:
            target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
            return prompt_profiled, target_profiled

        prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
        prompt_profile: Dict[str, float] = dict(prompt_base_profile or {})
        prompt_profile["cpu_preprocess_ms"] = float(prompt_cpu_run_ms)
        prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
            prompt_segments,
            target_segments,
@@ -589,20 +682,31 @@ class PrepareCoordinator:
        cpu_stage: PreparedCpuStage,
    ) -> tuple[T2SRequestState, float, float]:
        try:
            text_pair_start = time.perf_counter()
            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
            text_feature_pair_task = asyncio.create_task(
                self._run_text_feature_pair_stage(
            g2pw_pair_start = time.perf_counter()
            g2pw_pair_task = asyncio.create_task(
                self._run_g2pw_pair_stage(
                    cpu_stage.prompt_cpu_profiled.result,
                    cpu_stage.target_cpu_profiled.result,
                    cpu_stage.prompt_cpu_profiled.run_ms,
                    cpu_stage.target_cpu_profiled.run_ms,
                )
            )
            (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather(
                text_feature_pair_task,
            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path)))
            (prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather(
                g2pw_pair_task,
                ref_audio_task,
            )
            g2pw_pair_end = time.perf_counter()
            text_pair_start = time.perf_counter()
            text_feature_pair_task = asyncio.create_task(
                self._run_text_feature_pair_stage(
                    prompt_g2pw_profiled.result,
                    target_g2pw_profiled.result,
                    cpu_stage.prompt_cpu_profiled.run_ms,
                    cpu_stage.target_cpu_profiled.run_ms,
                    prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}),
                    target_base_profile=dict(target_g2pw_profiled.profile or {}),
                )
            )
            prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task
            text_pair_end = time.perf_counter()
            state = build_request_state_from_parts(
                tts=self.tts,
@@ -619,6 +723,17 @@ class PrepareCoordinator:
                "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms,
                "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0),
                "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
                "g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0),
                "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms,
                "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms,
                "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
                "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
                "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
                "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms,
                "text_g2pw_run_ms": target_g2pw_profiled.run_ms,
                "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)),
                "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)),
                "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)),
                "prompt_text_parallel_future_wait_ms": 0.0,
                "prompt_text_parallel_future_executor_queue_ms": 0.0,
                "prompt_text_parallel_future_run_ms": 0.0,
@@ -241,6 +241,24 @@ def build_request_state_from_parts(
            prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0)
        ),
        "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)),
        "prompt_text_g2pw_total_ms": float(prompt_result.profile.get("g2pw_total_ms", 0.0)),
        "prompt_text_g2pw_prepare_ms": float(prompt_result.profile.get("g2pw_prepare_ms", 0.0)),
        "prompt_text_g2pw_predict_ms": float(prompt_result.profile.get("g2pw_predict_ms", 0.0)),
        "prompt_text_g2pw_post_ms": float(prompt_result.profile.get("g2pw_post_ms", 0.0)),
        "prompt_text_g2pw_runtime_total_ms": float(prompt_result.profile.get("g2pw_runtime_total_ms", 0.0)),
        "prompt_text_g2pw_runtime_queue_wait_ms": float(
            prompt_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)
        ),
        "prompt_text_g2pw_runtime_collect_wait_ms": float(
            prompt_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)
        ),
        "prompt_text_g2pw_runtime_run_ms": float(prompt_result.profile.get("g2pw_runtime_run_ms", 0.0)),
        "prompt_text_g2pw_runtime_batch_rows_peak": float(
            prompt_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)
        ),
        "prompt_text_g2pw_runtime_batch_requests_peak": float(
            prompt_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0)
        ),
        "prompt_text_parallel_future_wait_ms": 0.0,
        "prompt_text_parallel_future_executor_queue_ms": 0.0,
        "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms),
@@ -267,6 +285,18 @@ def build_request_state_from_parts(
        ),
        "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)),
        "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)),
        "text_g2pw_total_ms": float(target_result.profile.get("g2pw_total_ms", 0.0)),
        "text_g2pw_prepare_ms": float(target_result.profile.get("g2pw_prepare_ms", 0.0)),
        "text_g2pw_predict_ms": float(target_result.profile.get("g2pw_predict_ms", 0.0)),
        "text_g2pw_post_ms": float(target_result.profile.get("g2pw_post_ms", 0.0)),
        "text_g2pw_runtime_total_ms": float(target_result.profile.get("g2pw_runtime_total_ms", 0.0)),
        "text_g2pw_runtime_queue_wait_ms": float(target_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)),
        "text_g2pw_runtime_collect_wait_ms": float(target_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)),
        "text_g2pw_runtime_run_ms": float(target_result.profile.get("g2pw_runtime_run_ms", 0.0)),
        "text_g2pw_runtime_batch_rows_peak": float(target_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)),
        "text_g2pw_runtime_batch_requests_peak": float(
            target_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0)
        ),
        "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)),
        "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
        "audio_load_ms": audio_load_ms,
@@ -8,6 +8,7 @@ sys.path.append(now_dir)

from text.LangSegmenter import LangSegmenter
from text import cleaned_text_to_sequence
from text import chinese2
from text.cleaner import clean_text


@@ -83,16 +84,27 @@ def preprocess_text_segments_payload(
    payloads: List[PreparedTextSegmentPayload] = []
    total_phones_len = 0
    for segment_text, segment_lang in zip(textlist, langlist):
        phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
        normalized_language = segment_lang.replace("all_", "")
        if normalized_language == "zh":
            norm_text = chinese2.text_normalize(segment_text)
            phones = []
            word2ph = None
            needs_g2pw = True
            estimated_phones_len = max(0, len(norm_text) * 2)
        else:
            phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
            needs_g2pw = False
            estimated_phones_len = len(phones)
        payloads.append(
            {
                "language": segment_lang.replace("all_", ""),
                "language": normalized_language,
                "phones": phones,
                "word2ph": word2ph,
                "norm_text": norm_text,
                "needs_g2pw": needs_g2pw,
            }
        )
        total_phones_len += len(phones)
        total_phones_len += int(estimated_phones_len)

    if not final and total_phones_len < 6:
        return preprocess_text_segments_payload("." + text, language, version, final=True)
@@ -1,5 +1,6 @@
import os
import re
import time

import cn2an
from pypinyin import lazy_pinyin, Style
@@ -77,6 +78,205 @@ def g2p(text):
    return phones, word2ph


def _prepare_g2p_segments(segments):
    prepared_segments = []
    batch_inputs = []
    for segment in segments:
        processed_segment = re.sub("[a-zA-Z]+", "", segment)
        seg_cut = psg.lcut(processed_segment)
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        prepared_segments.append(
            {
                "segment": processed_segment,
                "seg_cut": seg_cut,
            }
        )
        if processed_segment:
            batch_inputs.append(processed_segment)
    return prepared_segments, batch_inputs


def _build_segment_from_g2pw(segment: str, seg_cut, pinyins):
    phones_list = []
    word2ph = []
    initials = []
    finals = []
    pre_word_length = 0
    for word, pos in seg_cut:
        sub_initials = []
        sub_finals = []
        now_word_length = pre_word_length + len(word)

        if pos == "eng":
            pre_word_length = now_word_length
            continue

        word_pinyins = pinyins[pre_word_length:now_word_length]
        word_pinyins = correct_pronunciation(word, word_pinyins)

        for pinyin in word_pinyins:
            if pinyin[0].isalpha():
                sub_initials.append(to_initials(pinyin))
                sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
            else:
                sub_initials.append(pinyin)
                sub_finals.append(pinyin)

        pre_word_length = now_word_length
        sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
        sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
        initials.append(sub_initials)
        finals.append(sub_finals)

    initials = sum(initials, [])
    finals = sum(finals, [])
    for c, v in zip(initials, finals):
        raw_pinyin = c + v
        if c == v:
            assert c in punctuation
            phone = [c]
            word2ph.append(1)
        else:
            v_without_tone = v[:-1]
            tone = v[-1]

            pinyin = c + v_without_tone
            assert tone in "12345"

            if c:
                v_rep_map = {
                    "uei": "ui",
                    "iou": "iu",
                    "uen": "un",
                }
                if v_without_tone in v_rep_map.keys():
                    pinyin = c + v_rep_map[v_without_tone]
            else:
                pinyin_rep_map = {
                    "ing": "ying",
                    "i": "yi",
                    "in": "yin",
                    "u": "wu",
                }
                if pinyin in pinyin_rep_map.keys():
                    pinyin = pinyin_rep_map[pinyin]
                else:
                    single_rep_map = {
                        "v": "yu",
                        "e": "e",
                        "i": "y",
                        "u": "w",
                    }
                    if pinyin[0] in single_rep_map.keys():
                        pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

            assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
            new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
            new_v = new_v + tone
            phone = [new_c, new_v]
            word2ph.append(len(phone))

        phones_list += phone
    return phones_list, word2ph


def _build_segment_without_g2pw(segment: str, seg_cut):
    initials = []
    finals = []
    for word, pos in seg_cut:
        if pos == "eng":
            continue
        sub_initials, sub_finals = _get_initials_finals(word)
        sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
        sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
        initials.append(sub_initials)
        finals.append(sub_finals)
    phones_list = []
    word2ph = []
    for c, v in zip(sum(initials, []), sum(finals, [])):
        raw_pinyin = c + v
        if c == v:
            assert c in punctuation
            phone = [c]
            word2ph.append(1)
        else:
            v_without_tone = v[:-1]
            tone = v[-1]
            pinyin = c + v_without_tone
            assert tone in "12345"
            if c:
                v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"}
                if v_without_tone in v_rep_map:
                    pinyin = c + v_rep_map[v_without_tone]
            else:
                pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"}
                if pinyin in pinyin_rep_map:
                    pinyin = pinyin_rep_map[pinyin]
                else:
                    single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"}
                    if pinyin[0] in single_rep_map:
                        pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
            assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
            new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
            new_v = new_v + tone
            phone = [new_c, new_v]
            word2ph.append(len(phone))
        phones_list += phone
    return phones_list, word2ph


def g2p_segments(segments, return_profile: bool = False):
    prepare_start = time.perf_counter()
    prepared_segments, batch_inputs = _prepare_g2p_segments(segments)
    profile = {
        "g2pw_prepare_ms": 0.0,
        "g2pw_predict_ms": 0.0,
        "g2pw_post_ms": 0.0,
        "g2pw_runtime_total_ms": 0.0,
        "g2pw_runtime_queue_wait_ms": 0.0,
        "g2pw_runtime_collect_wait_ms": 0.0,
        "g2pw_runtime_run_ms": 0.0,
        "g2pw_runtime_batch_rows": 0.0,
        "g2pw_runtime_batch_requests": 0.0,
        "g2pw_runtime_pool_workers": 0.0,
        "g2pw_runtime_shard_index": 0.0,
    }
    profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0)
    if is_g2pw and batch_inputs:
        converter = g2pw._g2pw
        if hasattr(converter, "predict_sentences_with_profile"):
            g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs)
            for key, value in dict(predict_profile or {}).items():
                profile[key] = float(value)
        else:
            predict_start = time.perf_counter()
            g2pw_batch_results = converter(batch_inputs)
            profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0)
    else:
        g2pw_batch_results = []
    post_start = time.perf_counter()
    results = []
    batch_cursor = 0
    for item in prepared_segments:
        segment = item["segment"]
        if not segment:
            results.append(([], [], segment))
            continue
        if not is_g2pw:
            phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"])
            results.append((phones, word2ph, segment))
            continue
        pinyins = g2pw_batch_results[batch_cursor]
        batch_cursor += 1
        phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins)
        results.append((phones, word2ph, segment))
    profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0)
    profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"])
    if return_profile:
        return results, profile
    return results
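For reference, a small usage sketch of the new g2p_segments entry point (the input strings are illustrative and assumed to be already normalized):

    # Batch two normalized Chinese segments through one g2pw call and
    # inspect the per-stage timings returned alongside the results.
    results, profile = g2p_segments(["你好世界", "今天天气不错"], return_profile=True)
    for phones, word2ph, segment in results:
        print(segment, phones, word2ph)
    print("total ms:", profile["g2pw_total_ms"])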
def _get_initials_finals(word):
    initials = []
    finals = []
@@ -180,125 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
def _g2p(segments):
    phones_list = []
    word2ph = []
    g2pw_batch_results = []
    g2pw_batch_cursor = 0
    processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
    if is_g2pw:
        batch_inputs = [seg for seg in processed_segments if seg]
        g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []

    for seg in processed_segments:
        pinyins = []
        seg_cut = psg.lcut(seg)
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        initials = []
        finals = []

        if not is_g2pw:
            for word, pos in seg_cut:
                if pos == "eng":
                    continue
                sub_initials, sub_finals = _get_initials_finals(word)
                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
                # erhua (rhotacization)
                sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
                initials.append(sub_initials)
                finals.append(sub_finals)
                # assert len(sub_initials) == len(sub_finals) == len(word)
            initials = sum(initials, [])
            finals = sum(finals, [])
            print("pypinyin结果", initials, finals)
        else:
            # g2pw uses whole-sentence inference (batched; results consumed per segment)
            if seg:
                pinyins = g2pw_batch_results[g2pw_batch_cursor]
                g2pw_batch_cursor += 1

            pre_word_length = 0
            for word, pos in seg_cut:
                sub_initials = []
                sub_finals = []
                now_word_length = pre_word_length + len(word)

                if pos == "eng":
                    pre_word_length = now_word_length
                    continue

                word_pinyins = pinyins[pre_word_length:now_word_length]

                # polyphone disambiguation
                word_pinyins = correct_pronunciation(word, word_pinyins)

                for pinyin in word_pinyins:
                    if pinyin[0].isalpha():
                        sub_initials.append(to_initials(pinyin))
                        sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
                    else:
                        sub_initials.append(pinyin)
                        sub_finals.append(pinyin)

                pre_word_length = now_word_length
                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
                # erhua (rhotacization)
                sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
                initials.append(sub_initials)
                finals.append(sub_finals)

            initials = sum(initials, [])
            finals = sum(finals, [])
            # print("g2pw结果",initials,finals)

        for c, v in zip(initials, finals):
            raw_pinyin = c + v
            # NOTE: post process for pypinyin outputs
            # we discriminate i, ii and iii
            if c == v:
                assert c in punctuation
                phone = [c]
                word2ph.append(1)
            else:
                v_without_tone = v[:-1]
                tone = v[-1]

                pinyin = c + v_without_tone
                assert tone in "12345"

                if c:
                    # multi-syllable
                    v_rep_map = {
                        "uei": "ui",
                        "iou": "iu",
                        "uen": "un",
                    }
                    if v_without_tone in v_rep_map.keys():
                        pinyin = c + v_rep_map[v_without_tone]
                else:
                    # single-syllable
                    pinyin_rep_map = {
                        "ing": "ying",
                        "i": "yi",
                        "in": "yin",
                        "u": "wu",
                    }
                    if pinyin in pinyin_rep_map.keys():
                        pinyin = pinyin_rep_map[pinyin]
                    else:
                        single_rep_map = {
                            "v": "yu",
                            "e": "e",
                            "i": "y",
                            "u": "w",
                        }
                        if pinyin[0] in single_rep_map.keys():
                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
                new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
                new_v = new_v + tone
                phone = [new_c, new_v]
                word2ph.append(len(phone))

            phones_list += phone
    for phones, item_word2ph, _segment in g2p_segments(segments):
        phones_list += phones
        word2ph += item_word2ph
    return phones_list, word2ph
GPT_SoVITS/text/g2pw/cuda_api.py (new file, 670 lines)
@@ -0,0 +1,670 @@
|
||||
import ctypes
|
||||
import fcntl
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Deque, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .onnx_api import _G2PWBaseOnnxConverter
|
||||
|
||||
|
||||
class G2PWCudaError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class G2PWBatchTask:
|
||||
model_input: Dict[str, np.ndarray]
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
enqueued_at: float = 0.0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
output: np.ndarray | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
error: Exception | None = None
|
||||
|
||||
|
||||
_ROOT_DIR = Path(__file__).resolve().parents[3]
|
||||
_PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge"
|
||||
_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp"
|
||||
_LOCK_PATH = _OUTPUT_DIR / "build.lock"
|
||||
|
||||
|
||||
def _env_flag(name: str, default: bool) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return 1 if default else 0
|
||||
return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None or raw.strip() == "":
|
||||
return int(default)
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _resolve_cuda_root() -> Path:
|
||||
env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip()
|
||||
candidates = [
|
||||
env_root,
|
||||
_ROOT_DIR / "third_party" / "g2pw-cu",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
path = Path(candidate).expanduser().resolve()
|
||||
if path.exists():
|
||||
return path
|
||||
checked = [
|
||||
str(Path(candidate).expanduser().resolve())
|
||||
for candidate in candidates
|
||||
if str(candidate).strip() != ""
|
||||
]
|
||||
raise G2PWCudaError(
|
||||
"Cannot locate g2pw-cu root. "
|
||||
"Expected one of: "
|
||||
f"{checked}. "
|
||||
"Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into "
|
||||
f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} "
|
||||
"or set GPTSOVITS_G2PW_CUDA_ROOT explicitly."
|
||||
)
|
||||
|
||||
|
||||
def _resolve_runtime_paths() -> tuple[Path, Path, Path]:
|
||||
cuda_root = _resolve_cuda_root()
|
||||
runtime_lib = Path(
|
||||
os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so"))
|
||||
).expanduser()
|
||||
manifest_path = Path(
|
||||
os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt"))
|
||||
).expanduser()
|
||||
weights_path = Path(
|
||||
os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin"))
|
||||
).expanduser()
|
||||
for path in (runtime_lib, manifest_path, weights_path):
|
||||
if not path.exists():
|
||||
raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}")
|
||||
return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve()
|
||||
|
||||
|
||||
def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None:
|
||||
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
compile_cmd = [
|
||||
os.environ.get("CXX", "g++"),
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-shared",
|
||||
"-fPIC",
|
||||
str(_WRAPPER_SOURCE),
|
||||
"-I",
|
||||
str(runtime_lib.parent.parent / "include"),
|
||||
"-L",
|
||||
str(runtime_lib.parent),
|
||||
"-lg2pw_runtime",
|
||||
f"-Wl,-rpath,{runtime_lib.parent}",
|
||||
"-o",
|
||||
str(wrapper_output),
|
||||
]
|
||||
result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0:
|
||||
raise G2PWCudaError(
|
||||
"Failed to build g2pw-cu bridge:\n"
|
||||
f"cmd={' '.join(compile_cmd)}\n"
|
||||
f"stdout={result.stdout}\n"
|
||||
f"stderr={result.stderr}"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_bridge_built(runtime_lib: Path) -> Path:
|
||||
wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so"
|
||||
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _LOCK_PATH.open("w", encoding="utf-8") as lock_file:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
needs_build = not wrapper_output.exists()
|
||||
if not needs_build:
|
||||
so_mtime = wrapper_output.stat().st_mtime
|
||||
needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime
|
||||
if needs_build:
|
||||
tmp_output = wrapper_output.with_suffix(".tmp.so")
|
||||
if tmp_output.exists():
|
||||
tmp_output.unlink()
|
||||
_build_bridge(tmp_output, runtime_lib)
|
||||
tmp_output.replace(wrapper_output)
|
||||
return wrapper_output
|
||||
|
||||
|
||||
def _load_bridge():
|
||||
runtime_lib, manifest_path, weights_path = _resolve_runtime_paths()
|
||||
bridge_path = _ensure_bridge_built(runtime_lib)
|
||||
global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0))
|
||||
ctypes.CDLL(str(runtime_lib), mode=global_mode)
|
||||
lib = ctypes.CDLL(str(bridge_path))
|
||||
lib.g2pw_runtime_create.argtypes = [
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
]
|
||||
lib.g2pw_runtime_create.restype = ctypes.c_void_p
|
||||
lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p]
|
||||
lib.g2pw_runtime_destroy.restype = None
|
||||
lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
|
||||
lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
|
||||
lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]
|
||||
lib.g2pw_runtime_num_labels.restype = ctypes.c_int
|
||||
lib.g2pw_runtime_run.argtypes = [
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_void_p,
|
||||
]
|
||||
lib.g2pw_runtime_run.restype = ctypes.c_int
|
||||
return lib, manifest_path, weights_path, runtime_lib
|
||||
|
||||
|
||||
def _gemm_precision_value() -> int:
|
||||
precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower()
|
||||
if precision == "fp16":
|
||||
return 1
|
||||
if precision == "bf16":
|
||||
return 2
|
||||
return 0
|
||||
|
||||
|
||||
class G2PWRuntimeWrapper:
|
||||
def __init__(self, shard_index: int = 0) -> None:
|
||||
self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge()
|
||||
self.shard_index = int(shard_index)
|
||||
self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0)
|
||||
self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False)
|
||||
self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False)
|
||||
self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False)
|
||||
self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True)
|
||||
self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False)
|
||||
self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0)
|
||||
self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0)
|
||||
self.gemm_precision = _gemm_precision_value()
|
||||
self.lock = threading.Lock()
|
||||
self.handle = None
|
||||
self.max_batch_size = 0
|
||||
self.max_seq_len = 0
|
||||
self.num_labels = 0
|
||||
self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0
|
||||
self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0)
|
||||
self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64))
|
||||
self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96))
|
||||
self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096))
|
||||
self.batch_condition = threading.Condition()
|
||||
self.pending_tasks: Deque[G2PWBatchTask] = deque()
|
||||
self.batch_total_tasks = 0
|
||||
self.batch_total_batches = 0
|
||||
self.batch_total_rows = 0
|
||||
self.batch_total_queue_wait_ms = 0.0
|
||||
self.batch_queue_wait_peak_ms = 0.0
|
||||
self.batch_total_collect_wait_ms = 0.0
|
||||
self.batch_collect_wait_peak_ms = 0.0
|
||||
self.batch_total_run_ms = 0.0
|
||||
self.batch_run_peak_ms = 0.0
|
||||
self.batch_rows_peak = 0
|
||||
self.batch_requests_peak = 0
|
||||
self.batch_pending_peak = 0
|
||||
self.closed = False
|
||||
self._ensure_capacity(
|
||||
batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)),
|
||||
seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)),
|
||||
)
|
||||
self.batch_worker = None
|
||||
if self.batch_enabled:
|
||||
self.batch_worker = threading.Thread(
|
||||
target=self._batch_loop,
|
||||
name=f"g2pw-cuda-batch-worker-{self.shard_index}",
|
||||
daemon=True,
|
||||
)
|
||||
self.batch_worker.start()
|
||||
|
||||
def _destroy_handle(self) -> None:
|
||||
if self.handle:
|
||||
self.lib.g2pw_runtime_destroy(self.handle)
|
||||
self.handle = None
|
||||
|
||||
def close(self) -> None:
|
||||
with self.batch_condition:
|
||||
self.closed = True
|
||||
self.batch_condition.notify_all()
|
||||
self._destroy_handle()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _last_error(self) -> str:
|
||||
if not self.handle:
|
||||
return "uninitialized runtime"
|
||||
message = self.lib.g2pw_runtime_last_error(self.handle)
|
||||
return "" if not message else message.decode("utf-8", errors="replace")
|
||||
|
||||
def _create_handle(self, batch_size: int, seq_len: int) -> None:
|
||||
new_handle = self.lib.g2pw_runtime_create(
|
||||
str(self.manifest_path).encode("utf-8"),
|
||||
str(self.weights_path).encode("utf-8"),
|
||||
int(self.device_ordinal),
|
||||
int(batch_size),
|
||||
int(seq_len),
|
||||
int(self.full_graph_cache_limit),
|
||||
int(self.tail_graph_cache_limit),
|
||||
int(self.allow_tensor_cores),
|
||||
int(self.use_cublaslt_bias_epilogue),
|
||||
int(self.enable_profiling),
|
||||
int(self.enable_cuda_graph),
|
||||
int(self.dump_graph_cache_stats),
|
||||
int(self.gemm_precision),
|
||||
)
|
||||
if not new_handle:
|
||||
raise G2PWCudaError("g2pw-cu returned null runtime handle")
|
||||
self.handle = new_handle
|
||||
self.max_batch_size = int(batch_size)
|
||||
self.max_seq_len = int(seq_len)
|
||||
self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle))
|
||||
last_error = self._last_error()
|
||||
if self.num_labels <= 0 or last_error:
|
||||
self.close()
|
||||
raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}")
|
||||
|
||||
def _ensure_capacity(self, batch_size: int, seq_len: int) -> None:
|
||||
target_batch = max(1, int(batch_size))
|
||||
target_seq = max(1, int(seq_len))
|
||||
if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len:
|
||||
return
|
||||
next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0)
|
||||
next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0)
|
||||
self._destroy_handle()
|
||||
self._create_handle(batch_size=next_batch, seq_len=next_seq)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||
input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64)
|
||||
token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64)
|
||||
attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64)
|
||||
phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32)
|
||||
char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64)
|
||||
position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64)
|
||||
batch_size = int(char_ids.shape[0])
|
||||
if input_ids.shape[0] == 1 and batch_size > 1:
|
||||
input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64)
|
||||
token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64)
|
||||
attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_masks": attention_masks,
|
||||
"phoneme_masks": phoneme_masks,
|
||||
"char_ids": char_ids,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
normalized = self._normalize_model_input(model_input)
|
||||
input_ids = normalized["input_ids"]
|
||||
token_type_ids = normalized["token_type_ids"]
|
||||
attention_masks = normalized["attention_masks"]
|
||||
phoneme_masks = normalized["phoneme_masks"]
|
||||
char_ids = normalized["char_ids"]
|
||||
position_ids = normalized["position_ids"]
|
||||
batch_size = int(char_ids.shape[0])
|
||||
seq_len = int(input_ids.shape[1])
|
||||
probs = np.empty((batch_size, self.num_labels), dtype=np.float32)
|
||||
with self.lock:
|
||||
self._ensure_capacity(batch_size=batch_size, seq_len=seq_len)
|
||||
status = self.lib.g2pw_runtime_run(
|
||||
self.handle,
|
||||
input_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
token_type_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
attention_masks.ctypes.data_as(ctypes.c_void_p),
|
||||
phoneme_masks.ctypes.data_as(ctypes.c_void_p),
|
||||
char_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
position_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
batch_size,
|
||||
seq_len,
|
||||
probs.ctypes.data_as(ctypes.c_void_p),
|
||||
)
|
||||
if int(status) != 0:
|
||||
raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}")
|
||||
return probs
|
||||
|
||||
def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool:
|
||||
request_count = len(tasks) + 1
|
||||
if request_count > self.batch_max_requests:
|
||||
return False
|
||||
total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int(
|
||||
candidate.model_input["char_ids"].shape[0]
|
||||
)
|
||||
if total_rows > self.batch_max_rows:
|
||||
return False
|
||||
total_tokens = sum(
|
||||
int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks
|
||||
) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1])
|
||||
return total_tokens <= self.batch_max_tokens
|
||||
|
||||
def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]:
|
||||
normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks]
|
||||
total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs)
|
||||
max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs)
|
||||
input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
|
||||
token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
|
||||
attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64)
|
||||
phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32)
|
||||
char_ids = np.zeros((total_rows,), dtype=np.int64)
|
||||
position_ids = np.zeros((total_rows,), dtype=np.int64)
|
||||
slices: List[Tuple[int, int]] = []
|
||||
cursor = 0
|
||||
for item in normalized_inputs:
|
||||
rows = int(item["char_ids"].shape[0])
|
||||
seq_len = int(item["input_ids"].shape[1])
|
||||
next_cursor = cursor + rows
|
||||
input_ids[cursor:next_cursor, :seq_len] = item["input_ids"]
|
||||
token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"]
|
||||
attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"]
|
||||
phoneme_masks[cursor:next_cursor] = item["phoneme_masks"]
|
||||
char_ids[cursor:next_cursor] = item["char_ids"]
|
||||
position_ids[cursor:next_cursor] = item["position_ids"]
|
||||
slices.append((cursor, next_cursor))
|
||||
cursor = next_cursor
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_masks": attention_masks,
|
||||
"phoneme_masks": phoneme_masks,
|
||||
"char_ids": char_ids,
|
||||
"position_ids": position_ids,
|
||||
}, slices
|
||||
|
||||
def _finish_task(
|
||||
self,
|
||||
task: G2PWBatchTask,
|
||||
output: np.ndarray | None = None,
|
||||
profile: Dict[str, float] | None = None,
|
||||
error: Exception | None = None,
|
||||
) -> None:
|
||||
task.output = output
|
||||
task.profile = dict(profile or {})
|
||||
task.error = error
|
||||
task.done_event.set()
|
||||
|
||||
def _batch_loop(self) -> None:
|
||||
while True:
|
||||
with self.batch_condition:
|
||||
while not self.pending_tasks and not self.closed:
|
||||
self.batch_condition.wait()
|
||||
if self.closed and not self.pending_tasks:
|
||||
return
|
||||
first_task = self.pending_tasks.popleft()
|
||||
batch_tasks = [first_task]
|
||||
collect_started = time.perf_counter()
|
||||
deadline = collect_started + self.batch_window_s
|
||||
while True:
|
||||
if len(batch_tasks) >= self.batch_max_requests:
|
||||
break
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0.0:
|
||||
break
|
||||
if not self.pending_tasks:
|
||||
self.batch_condition.wait(timeout=remaining)
|
||||
continue
|
||||
candidate = self.pending_tasks[0]
|
||||
if not self._can_append_task(batch_tasks, candidate):
|
||||
break
|
||||
batch_tasks.append(self.pending_tasks.popleft())
|
||||
collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0)
|
||||
|
||||
now = time.perf_counter()
|
||||
queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks]
|
||||
try:
|
||||
merged_input, row_slices = self._merge_batch_inputs(batch_tasks)
|
||||
run_started = time.perf_counter()
|
||||
merged_output = self._run_direct(merged_input)
|
||||
run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0)
|
||||
for task, (start, end) in zip(batch_tasks, row_slices):
|
||||
task_rows = int(task.model_input["char_ids"].shape[0])
|
||||
task_seq_len = int(task.model_input["input_ids"].shape[1])
|
||||
self._finish_task(
|
||||
task,
|
||||
output=np.ascontiguousarray(merged_output[start:end]),
|
||||
profile={
|
||||
"g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)),
|
||||
"g2pw_runtime_collect_wait_ms": float(collect_wait_ms),
|
||||
"g2pw_runtime_run_ms": float(run_ms),
|
||||
"g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)),
|
||||
"g2pw_runtime_batch_requests": float(len(batch_tasks)),
|
||||
"g2pw_runtime_task_rows": float(task_rows),
|
||||
"g2pw_runtime_task_seq_len": float(task_seq_len),
|
||||
"g2pw_runtime_shard_index": float(self.shard_index),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
run_ms = 0.0
|
||||
for task in batch_tasks:
|
||||
self._finish_task(task, error=exc)
|
||||
finally:
|
||||
with self.batch_condition:
|
||||
self.batch_total_batches += 1
|
||||
self.batch_total_tasks += len(batch_tasks)
|
||||
self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
|
||||
self.batch_total_queue_wait_ms += float(sum(queue_wait_values))
|
||||
self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0]))
|
||||
self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks))
|
||||
self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms))
|
||||
self.batch_total_run_ms += float(run_ms)
|
||||
self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms))
|
||||
self.batch_rows_peak = max(
|
||||
self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
|
||||
)
|
||||
self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks))
|
||||
|
||||
def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
|
||||
task = G2PWBatchTask(model_input=model_input)
|
||||
with self.batch_condition:
|
||||
if self.closed:
|
||||
raise G2PWCudaError("g2pw-cu batch worker already closed")
|
||||
task.enqueued_at = time.perf_counter()
|
||||
self.pending_tasks.append(task)
|
||||
self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks))
|
||||
self.batch_condition.notify_all()
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.output is not None
|
||||
return task.output, dict(task.profile)
|
||||
|
||||
    def snapshot(self) -> Dict[str, float | int | bool]:
        with self.batch_condition:
            average_tasks_per_batch = (
                float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
            )
            average_rows_per_batch = (
                float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
            )
            average_queue_wait_ms = (
                float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0
            )
            average_collect_wait_ms = (
                float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks)
                if self.batch_total_tasks > 0
                else 0.0
            )
            return {
                "shard_index": int(self.shard_index),
                "enabled": bool(self.batch_enabled),
                "window_ms": float(self.batch_window_s * 1000.0),
                "max_requests": int(self.batch_max_requests),
                "max_rows": int(self.batch_max_rows),
                "max_tokens": int(self.batch_max_tokens),
                "pending": int(len(self.pending_tasks)),
                "pending_peak": int(self.batch_pending_peak),
                "total_batches": int(self.batch_total_batches),
                "total_tasks": int(self.batch_total_tasks),
                "total_rows": int(self.batch_total_rows),
                "avg_tasks_per_batch": float(average_tasks_per_batch),
                "avg_rows_per_batch": float(average_rows_per_batch),
                "avg_queue_wait_ms": float(average_queue_wait_ms),
                "queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms),
                "avg_collect_wait_ms": float(average_collect_wait_ms),
                "collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms),
                "run_total_ms": float(self.batch_total_run_ms),
                "run_peak_ms": float(self.batch_run_peak_ms),
                "batch_rows_peak": int(self.batch_rows_peak),
                "batch_requests_peak": int(self.batch_requests_peak),
            }
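Because the averages are derived from the running totals under the same lock, any snapshot is internally consistent; a quick sanity check one could run against a snapshot dict (a sketch, not part of the codebase):

def check_snapshot(snap: dict) -> None:
    # Averages are totals divided by their denominators, so they must agree.
    if snap["total_batches"] > 0:
        assert abs(snap["avg_rows_per_batch"] * snap["total_batches"] - snap["total_rows"]) < 1e-6
        assert abs(snap["avg_tasks_per_batch"] * snap["total_batches"] - snap["total_tasks"]) < 1e-6
    # pending_peak is updated on every enqueue, so it bounds current pending.
    assert snap["pending"] <= snap["pending_peak"]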
    def pending_rows(self) -> int:
        with self.batch_condition:
            return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks))

    def pending_count(self) -> int:
        with self.batch_condition:
            return int(len(self.pending_tasks))
    def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
        if not self.batch_enabled:
            started = time.perf_counter()
            output = self._run_direct(model_input)
            return output, {
                "g2pw_runtime_queue_wait_ms": 0.0,
                "g2pw_runtime_collect_wait_ms": 0.0,
                "g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0),
                "g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]),
                "g2pw_runtime_batch_requests": 1.0,
                "g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]),
                "g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]),
                "g2pw_runtime_shard_index": float(self.shard_index),
            }
        return self._submit_batched(model_input)

    def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
        output, _profile = self.run_with_profile(model_input)
        return output
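What a caller sees, sketched with the keys emitted above (wrapper stands for a constructed shard; its setup is elided):

output, profile = wrapper.run_with_profile(model_input)
# Time from enqueue to batch start, plus the merged forward pass itself;
# collect wait overlaps the queue wait, so it is not added on top.
latency_ms = profile["g2pw_runtime_queue_wait_ms"] + profile["g2pw_runtime_run_ms"]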
class G2PWRuntimePool:
    def __init__(self) -> None:
        self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2))
        self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)]
        self.lock = threading.Lock()

    def _pick_shard(self) -> G2PWRuntimeWrapper:
        with self.lock:
            return min(
                self.shards,
                key=lambda shard: (
                    shard.pending_rows(),
                    shard.pending_count(),
                    shard.snapshot().get("avg_queue_wait_ms", 0.0),
                ),
            )

    def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
        shard = self._pick_shard()
        output, profile = shard.run_with_profile(model_input)
        profile["g2pw_runtime_pool_workers"] = float(self.worker_count)
        return output, profile

    def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
        output, _profile = self.run_with_profile(model_input)
        return output
    def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]:
        shard_snapshots = [dict(shard.snapshot()) for shard in self.shards]
        avg_queue_wait_ms = 0.0
        total_tasks = 0.0
        pending = 0
        pending_peak = 0
        total_batches = 0
        total_rows = 0
        batch_rows_peak = 0
        batch_requests_peak = 0
        for snapshot in shard_snapshots:
            tasks = float(snapshot.get("total_tasks", 0.0))
            avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks
            total_tasks += tasks
            pending += int(snapshot.get("pending", 0))
            pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0)))
            total_batches += int(snapshot.get("total_batches", 0))
            total_rows += int(snapshot.get("total_rows", 0))
            batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0)))
            batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0)))
        return {
            "worker_count": int(self.worker_count),
            "pending": int(pending),
            "pending_peak": int(pending_peak),
            "total_batches": int(total_batches),
            "total_tasks": int(total_tasks),
            "total_rows": int(total_rows),
            "avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0,
            "batch_rows_peak": int(batch_rows_peak),
            "batch_requests_peak": int(batch_requests_peak),
            "shards": shard_snapshots,
        }
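_pick_shard above relies on Python's lexicographic tuple comparison: fewest pending rows wins, ties broken by pending request count, then by historical average queue wait. The same idea in isolation:

from typing import List, Tuple

def pick_least_loaded(loads: List[Tuple[int, int, float]]) -> int:
    # (pending_rows, pending_count, avg_queue_wait_ms), compared left to right
    return min(range(len(loads)), key=lambda i: loads[i])

print(pick_least_loaded([(10, 2, 1.5), (10, 1, 9.0), (12, 0, 0.0)]))  # -> 1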
class G2PWCudaConverter(_G2PWBaseOnnxConverter):
    def __init__(
        self,
        model_dir: str = "G2PWModel/",
        style: str = "bopomofo",
        model_source: str | None = None,
        enable_non_tradional_chinese: bool = False,
    ):
        super().__init__(
            model_dir=model_dir,
            style=style,
            model_source=model_source,
            enable_non_tradional_chinese=enable_non_tradional_chinese,
        )
        self.runtime = G2PWRuntimePool()
        self.backend = "cuda"
        primary_runtime = self.runtime.shards[0]
        self.device = f"cuda:{primary_runtime.device_ordinal}"
        self.checkpoint_path = str(primary_runtime.weights_path)
        self.providers = ["g2pw-cu"]

    def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
        probs = self.runtime.run(model_input)
        preds = np.argmax(probs, axis=1).tolist()
        confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
        return [self.labels[pred] for pred in preds], confidences

    def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]:
        started = time.perf_counter()
        probs, runtime_profile = self.runtime.run_with_profile(model_input)
        preds = np.argmax(probs, axis=1).tolist()
        confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
        profile = dict(runtime_profile)
        profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0)
        profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"])
        return [self.labels[pred] for pred in preds], confidences, profile

    def snapshot(self) -> Dict[str, float | int | bool]:
        return dict(self.runtime.snapshot())
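A hedged smoke test for the converter (assumes the CUDA runtime artifacts and the G2PWModel data are in place; the sentence is arbitrary):

converter = G2PWCudaConverter(model_dir="G2PWModel/")
readings = converter(["上海是一个直辖市"])  # List[List[str]] of per-character predictions
print(readings[0])
print(converter.snapshot()["worker_count"])  # pool-level stats, see G2PWRuntimePool.snapshot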
@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style
from pypinyin.seg.simpleseg import simple_seg
from pypinyin.converter import UltimateConverter
from pypinyin.contrib.tone_convert import to_tone
from .cuda_api import G2PWCudaConverter
from .onnx_api import G2PWOnnxConverter

current_file_path = os.path.dirname(__file__)
@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin):
        tone_sandhi=False,
        **kwargs,
    ):
        self._g2pw = G2PWOnnxConverter(
            model_dir=model_dir,
            style="pinyin",
            model_source=model_source,
            enable_non_tradional_chinese=enable_non_tradional_chinese,
        )
        backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower()
        last_error = None
        self._g2pw = None
        if backend in {"cuda", "auto"}:
            try:
                self._g2pw = G2PWCudaConverter(
                    model_dir=model_dir,
                    style="pinyin",
                    model_source=model_source,
                    enable_non_tradional_chinese=enable_non_tradional_chinese,
                )
            except Exception as exc:
                last_error = exc
                strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in {
                    "1",
                    "true",
                    "yes",
                    "on",
                }
                if backend == "cuda" and strict_mode:
                    raise
        if self._g2pw is None:
            self._g2pw = G2PWOnnxConverter(
                model_dir=model_dir,
                style="pinyin",
                model_source=model_source,
                enable_non_tradional_chinese=enable_non_tradional_chinese,
            )
            if last_error is not None:
                print(f"[g2pw] cuda backend unavailable, falling back to onnx: {last_error}")
        self._converter = Converter(
            self._g2pw,
            v_to_u=v_to_u,
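The selection logic above is driven entirely by environment variables; the defaults it reads are:

import os

os.environ.setdefault("GPTSOVITS_G2PW_BACKEND", "cuda")    # "cuda" or "auto" try the CUDA path; anything else keeps onnx
os.environ.setdefault("GPTSOVITS_G2PW_CUDA_STRICT", "0")   # with backend "cuda", a truthy value re-raises instead of falling back
os.environ.setdefault("GPTSOVITS_G2PW_CUDA_WORKERS", "2")  # shard count used by G2PWRuntimePool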
183
GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp
Normal file
@ -0,0 +1,183 @@
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "g2pw/runtime.h"

namespace {

struct G2PWRuntimeHandle {
  std::unique_ptr<g2pw::Runtime> runtime;
  std::string last_error;
  int num_labels = 0;
};

void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) {
  if (handle == nullptr) {
    return;
  }
  handle->last_error = status.message;
}

g2pw::RuntimeConfig BuildConfig(
    int device_ordinal,
    int max_batch_size,
    int max_seq_len,
    int full_graph_cache_limit,
    int tail_graph_cache_limit,
    int allow_tensor_cores,
    int use_cublaslt_bias_epilogue,
    int enable_profiling,
    int enable_cuda_graph,
    int dump_graph_cache_stats,
    int gemm_precision) {
  g2pw::RuntimeConfig config{};
  config.device_ordinal = device_ordinal;
  config.max_batch_size = max_batch_size;
  config.max_seq_len = max_seq_len;
  config.full_graph_cache_limit = full_graph_cache_limit;
  config.tail_graph_cache_limit = tail_graph_cache_limit;
  config.allow_tensor_cores = allow_tensor_cores != 0;
  config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0;
  config.enable_profiling = enable_profiling != 0;
  config.enable_cuda_graph = enable_cuda_graph != 0;
  config.dump_graph_cache_stats = dump_graph_cache_stats != 0;
  switch (gemm_precision) {
    case 1:
      config.gemm_precision = g2pw::GemmPrecision::kFp16;
      break;
    case 2:
      config.gemm_precision = g2pw::GemmPrecision::kBf16;
      break;
    default:
      config.gemm_precision = g2pw::GemmPrecision::kFp32;
      break;
  }
  return config;
}

}  // namespace

extern "C" {

void* g2pw_runtime_create(
    const char* manifest_path,
    const char* binary_path,
    int device_ordinal,
    int max_batch_size,
    int max_seq_len,
    int full_graph_cache_limit,
    int tail_graph_cache_limit,
    int allow_tensor_cores,
    int use_cublaslt_bias_epilogue,
    int enable_profiling,
    int enable_cuda_graph,
    int dump_graph_cache_stats,
    int gemm_precision) {
  auto* handle = new G2PWRuntimeHandle();
  try {
    if (manifest_path == nullptr || binary_path == nullptr) {
      handle->last_error = "manifest_path and binary_path must be non-null";
      return handle;
    }
    g2pw::RuntimeConfig config = BuildConfig(
        device_ordinal,
        max_batch_size,
        max_seq_len,
        full_graph_cache_limit,
        tail_graph_cache_limit,
        allow_tensor_cores,
        use_cublaslt_bias_epilogue,
        enable_profiling,
        enable_cuda_graph,
        dump_graph_cache_stats,
        gemm_precision);
    g2pw::Status status = g2pw::Runtime::Create(
        config,
        std::string(manifest_path),
        std::string(binary_path),
        &handle->runtime);
    if (!status.ok()) {
      SetError(handle, status);
      return handle;
    }
    handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0;
    handle->last_error.clear();
    return handle;
  } catch (const std::exception& exc) {
    handle->last_error = exc.what();
    return handle;
  } catch (...) {
    handle->last_error = "unknown exception";
    return handle;
  }
}

void g2pw_runtime_destroy(void* raw_handle) {
  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
  delete handle;
}

const char* g2pw_runtime_last_error(void* raw_handle) {
  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
  if (handle == nullptr) {
    return "invalid runtime handle";
  }
  return handle->last_error.c_str();
}

int g2pw_runtime_num_labels(void* raw_handle) {
  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
  if (handle == nullptr || handle->runtime == nullptr) {
    return 0;
  }
  return handle->num_labels;
}

int g2pw_runtime_run(
    void* raw_handle,
    const std::int64_t* input_ids,
    const std::int64_t* token_type_ids,
    const std::int64_t* attention_mask,
    const float* phoneme_mask,
    const std::int64_t* char_ids,
    const std::int64_t* position_ids,
    std::int32_t batch_size,
    std::int32_t seq_len,
    float* probs) {
  auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
  if (handle == nullptr || handle->runtime == nullptr) {
    return static_cast<int>(g2pw::StatusCode::kInvalidArgument);
  }
  try {
    g2pw::InferenceInputs inputs{};
    inputs.input_ids = input_ids;
    inputs.token_type_ids = token_type_ids;
    inputs.attention_mask = attention_mask;
    inputs.phoneme_mask = phoneme_mask;
    inputs.char_ids = char_ids;
    inputs.position_ids = position_ids;
    inputs.batch_size = batch_size;
    inputs.seq_len = seq_len;

    g2pw::InferenceOutputs outputs{};
    outputs.probs = probs;

    const g2pw::Status status = handle->runtime->Run(inputs, outputs);
    if (!status.ok()) {
      SetError(handle, status);
      return static_cast<int>(status.code);
    }
    handle->last_error.clear();
    return static_cast<int>(g2pw::StatusCode::kOk);
  } catch (const std::exception& exc) {
    handle->last_error = exc.what();
    return static_cast<int>(g2pw::StatusCode::kInternalError);
  } catch (...) {
    handle->last_error = "unknown exception";
    return static_cast<int>(g2pw::StatusCode::kInternalError);
  }
}

}
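Since the bridge only exposes a C ABI, the Python side presumably binds it with ctypes or similar. A sketch of such a binding (the shared-object name and file paths are placeholders, not taken from the repo):

import ctypes

lib = ctypes.CDLL("libg2pw_cuda_bridge.so")  # placeholder library name
lib.g2pw_runtime_create.restype = ctypes.c_void_p
lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_num_labels.restype = ctypes.c_int
lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]

handle = lib.g2pw_runtime_create(
    b"manifest.json", b"weights.bin",  # placeholder paths
    0,        # device_ordinal
    32, 512,  # max_batch_size, max_seq_len
    8, 8,     # full/tail graph cache limits
    1, 1,     # allow_tensor_cores, use_cublaslt_bias_epilogue
    0, 1, 0,  # enable_profiling, enable_cuda_graph, dump_graph_cache_stats
    1,        # gemm_precision: 0=fp32, 1=fp16, 2=bf16
)
if lib.g2pw_runtime_num_labels(handle) == 0:
    # create() always returns a handle; failures are reported via last_error
    raise RuntimeError(lib.g2pw_runtime_last_error(handle).decode())

Note the error convention above: g2pw_runtime_create never returns null, so callers must consult num_labels or last_error before using the handle.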
@ -3,6 +3,7 @@

import json
import os
import time
import warnings
import zipfile
from typing import Any, Dict, List, Tuple
@ -71,6 +72,23 @@ def _find_first_existing_file(*paths: str) -> str:
    raise FileNotFoundError(f"Files not found: {paths}")


def _resolve_tokenizer_source(model_source: str | None) -> str:
    candidate_paths = []
    if model_source:
        candidate_paths.append(model_source)
    repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
    candidate_paths.extend(
        [
            os.path.join(repo_root, "pretrained_models", "g2pw-chinese"),
            os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"),
        ]
    )
    for candidate in candidate_paths:
        if candidate and os.path.exists(candidate):
            return candidate
    return model_source or "bert-base-chinese"
def download_and_decompress(model_dir: str = "G2PWModel/"):
    if not os.path.exists(model_dir):
        parent_directory = os.path.dirname(model_dir)
@ -106,9 +124,9 @@ class _G2PWBaseOnnxConverter:
        self.model_dir = download_and_decompress(model_dir)
        self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)

        self.model_source = model_source if model_source else self.config.model_source
        self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source)
        self.enable_opencc = enable_non_tradional_chinese
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True)

        polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
        monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
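Note how the two changes above interact: with local_files_only=True, the trailing "bert-base-chinese" fallback in _resolve_tokenizer_source only succeeds if that model is already in the local Hugging Face cache. A quick offline sanity check (a sketch):

from transformers import AutoTokenizer

src = _resolve_tokenizer_source(None)
try:
    AutoTokenizer.from_pretrained(src, local_files_only=True)
except OSError as exc:
    print(f"tokenizer source {src!r} is not available offline: {exc}")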
@ -200,6 +218,10 @@ class _G2PWBaseOnnxConverter:
        return None

    def __call__(self, sentences: List[str]) -> List[List[str]]:
        results, _profile = self.predict_sentences_with_profile(sentences)
        return results

    def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]:
        if isinstance(sentences, str):
            sentences = [sentences]

@ -213,7 +235,7 @@

        texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
        if len(texts) == 0:
            return partial_results
            return partial_results, {}

        model_input = prepare_onnx_input(
            tokenizer=self.tokenizer,
@ -229,12 +251,21 @@
        )

        if not model_input:
            return partial_results
            return partial_results, {}

        predict_profile: Dict[str, float] = {}
        if self.enable_sentence_dedup:
            preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
            preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled(
                model_input=model_input,
                texts=texts,
            )
        else:
            preds, _confidences = self._predict(model_input=model_input)
            if hasattr(self, "_predict_with_profile"):
                preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input)
            else:
                predict_started = time.perf_counter()
                preds, _confidences = self._predict(model_input=model_input)
                predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0)

        if self.config.use_char_phoneme:
            preds = [pred.split(" ")[1] for pred in preds]
@ -243,7 +274,7 @@
        for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
            results[sent_id][query_id] = self.style_convert_func(pred)

        return results
        return results, predict_profile
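Callers that only want readings keep using the old __call__ shape; callers that care about timing use the new tuple-returning method. For example (converter is any constructed converter instance):

readings, profile = converter.predict_sentences_with_profile(["他长大了"])
print(readings[0])
print(profile.get("g2pw_predict_ms", 0.0))  # profile is {} when no model query was needed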
    def _prepare_data(
        self, sentences: List[str]
@ -314,6 +345,52 @@

        return preds, confidences

    def _predict_with_sentence_dedup_profiled(
        self,
        model_input: Dict[str, Any],
        texts: List[str],
    ) -> Tuple[List[str], List[float], Dict[str, float]]:
        if len(texts) <= 1:
            if hasattr(self, "_predict_with_profile"):
                return self._predict_with_profile(model_input=model_input)
            predict_started = time.perf_counter()
            preds, confidences = self._predict(model_input=model_input)
            return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}

        grouped_indices: Dict[str, List[int]] = {}
        for idx, text in enumerate(texts):
            grouped_indices.setdefault(text, []).append(idx)

        if all(len(indices) == 1 for indices in grouped_indices.values()):
            if hasattr(self, "_predict_with_profile"):
                return self._predict_with_profile(model_input=model_input)
            predict_started = time.perf_counter()
            preds, confidences = self._predict(model_input=model_input)
            return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}

        preds: List[str] = [""] * len(texts)
        confidences: List[float] = [0.0] * len(texts)
        merged_profile: Dict[str, float] = {}
        for indices in grouped_indices.values():
            group_input = {name: value[indices] for name, value in model_input.items()}
            if len(indices) > 1:
                for name in ("input_ids", "token_type_ids", "attention_masks"):
                    group_input[name] = group_input[name][:1]
            if hasattr(self, "_predict_with_profile"):
                group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input)
                for key, value in dict(group_profile or {}).items():
                    merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value)
            else:
                predict_started = time.perf_counter()
                group_preds, group_confidences = self._predict(model_input=group_input)
                merged_profile["g2pw_predict_ms"] = float(
                    merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0
                )
            for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
                preds[output_idx] = pred
                confidences[output_idx] = confidence
        return preds, confidences, merged_profile
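The dedup path above hinges on grouping indices by identical sentence text so the sentence-level tensors are encoded once per unique sentence; the grouping step in isolation:

texts = ["你好", "世界", "你好"]
grouped: dict = {}
for idx, text in enumerate(texts):
    grouped.setdefault(text, []).append(idx)
print(grouped)  # {'你好': [0, 2], '世界': [1]}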
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
    def __init__(