Add g2pw submodule and enhance TTS processing with AsyncStageGate

Introduce a new submodule for g2pw and implement AsyncStageGate in PrepareCoordinator to manage concurrent task inflight limits. Update PrepareTextCpuWorker and PrepareRefSemanticBatchWorker to support asynchronous task submission and completion notifications. Enhance profiling capabilities in TTS to track g2pw processing times, improving overall performance and maintainability of the TTS system.
This commit is contained in:
baicai-1145 2026-03-12 23:03:33 +08:00
parent 6a822b28c3
commit 5cf68a91d3
13 changed files with 965 additions and 122 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "third_party/g2pw-cu"]
path = third_party/g2pw-cu
url = https://github.com/baicai-1145/g2pw-cu.git

View File

@ -1,4 +1,5 @@
import gc
import asyncio
import concurrent.futures
import math
import os
@ -42,6 +43,7 @@ from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
PrepareRefSemanticBatchWorker,
prepare_prompt_semantic_wav16k,
)
from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker
from sv import SV
resample_transform_dict = {}
@ -454,18 +456,12 @@ class TTS:
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None
self.prepare_text_cpu_worker = None
self.prepare_text_cpu_workers = max(
0,
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")),
)
self.prepare_text_cpu_executor = (
concurrent.futures.ThreadPoolExecutor(
max_workers=self.prepare_text_cpu_workers,
thread_name_prefix="prepare-text-cpu",
)
if self.prepare_text_cpu_workers > 0
else None
)
self.prepare_text_cpu_executor = None
self._init_models()
self.refresh_runtime_components()
@ -488,6 +484,7 @@ class TTS:
def refresh_runtime_components(self):
self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None
self.prepare_text_cpu_worker = None
if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
self.prepare_bert_batch_worker = PrepareBertBatchWorker(
bert_model=self.bert_model,
@ -535,6 +532,92 @@ class TTS:
bert_stage_limiter=self.prepare_bert_stage_limiter,
bert_batch_worker=self.prepare_bert_batch_worker,
)
if self.prepare_text_cpu_workers > 0:
self.prepare_text_cpu_worker = PrepareTextCpuWorker(
process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments(
text,
language,
self.configs.version,
),
worker_count=self.prepare_text_cpu_workers,
max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")),
admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")),
admission_controller=self._build_text_cpu_admission_state,
)
@staticmethod
def _safe_queue_qsize(executor) -> int | None:
if executor is None:
return None
queue = getattr(executor, "_work_queue", None)
if queue is None or not hasattr(queue, "qsize"):
return None
try:
return int(queue.qsize())
except Exception:
return None
def snapshot_prepare_runtime_components(self) -> dict:
return {
"text_cpu": {
"workers": int(self.prepare_text_cpu_workers),
"queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor),
"enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None),
"worker": (
None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot())
),
"admission": self._build_text_cpu_admission_state(),
},
"bert": {
"stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()),
"batch_worker": (
None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot())
),
"batching_enabled": bool(self.prepare_bert_batch_worker is not None),
},
"ref_semantic": {
"stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()),
"batch_worker": (
None
if self.prepare_ref_semantic_batch_worker is None
else dict(self.prepare_ref_semantic_batch_worker.snapshot())
),
"batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None),
},
"text_preprocessor": (
None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot()
),
}
def _build_text_cpu_admission_state(self) -> dict:
bert_pending_soft_max = max(
0,
int(
os.environ.get(
"GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX",
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"),
)
),
)
if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0:
return {
"blocked": False,
"reason": "",
"bert_pending": 0,
"bert_active_batch_size": 0,
"bert_pending_soft_max": int(bert_pending_soft_max),
}
bert_state = dict(self.prepare_bert_batch_worker.snapshot())
bert_pending = int(bert_state.get("pending", 0))
bert_active_batch_size = int(bert_state.get("active_batch_size", 0))
blocked = bert_pending >= bert_pending_soft_max
return {
"blocked": bool(blocked),
"reason": ("bert_pending" if blocked else ""),
"bert_pending": int(bert_pending),
"bert_active_batch_size": int(bert_active_batch_size),
"bert_pending_soft_max": int(bert_pending_soft_max),
}
def _init_models(
self,
@ -1040,6 +1123,79 @@ class TTS:
},
}
async def extract_ref_audio_bundle_async(self, ref_audio_path: str):
if self.prepare_ref_semantic_batch_worker is None:
return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path)
load_start = time.perf_counter()
raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0
prompt_semantic_task = asyncio.create_task(
self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
)
def _build_ref_spec_profile():
with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
ref_spec_start = time.perf_counter()
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
return refer_spec, {
"ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
"ref_spec_ms": float(ref_spec_ms),
"audio_stage_slots": float(ref_spec_limiter_stats["slots"]),
"audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]),
}
ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile))
(prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather(
prompt_semantic_task,
ref_spec_task,
)
prompt_semantic_ms = (
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
)
audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
ref_spec_profile.get("ref_spec_wait_ms", 0.0)
)
audio_stage_slots = max(
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
float(ref_spec_profile.get("audio_stage_slots", 0.0)),
)
audio_stage_inflight_peak = max(
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)),
)
return {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": float(load_ms),
"audio_stage_wait_ms": float(audio_stage_wait_ms),
"audio_stage_slots": float(audio_stage_slots),
"audio_stage_inflight_peak": float(audio_stage_inflight_peak),
"prompt_semantic_ms": float(prompt_semantic_ms),
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
),
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
"prompt_semantic_batch_samples": float(prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)),
"ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)),
"bundle_total_ms": float(load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)),
},
}
def extract_text_features(self, text: str, language: str, profile: dict | None = None):
return self.text_preprocessor.segment_and_extract_feature_for_text(
text, language, self.configs.version, profile=profile

View File

@ -118,6 +118,15 @@ class TextPreprocessor:
self.bert_stage_limiter = bert_stage_limiter
self.bert_batch_worker = bert_batch_worker
def snapshot(self) -> Dict[str, object]:
return {
"device": str(self.device),
"bert_stage_limiter": (
None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot())
),
"bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()),
}
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)

View File

@ -24,6 +24,7 @@ class ProfiledResult:
submit_at: float
started_at: float
finished_at: float
profile: Dict[str, float] | None = None
@property
def queue_ms(self) -> float:
@ -48,6 +49,52 @@ class PreparedCpuStage:
target_cpu_profiled: ProfiledResult
class AsyncStageGate:
def __init__(self, max_inflight: int, poll_ms: int = 1):
self.max_inflight = max(0, int(max_inflight))
self.lock = threading.Lock()
self.poll_s = max(0.0005, float(max(1, int(poll_ms))) / 1000.0)
self.inflight = 0
self.peak_inflight = 0
self.total_entered = 0
self.total_wait_ms = 0.0
self.wait_peak_ms = 0.0
async def acquire(self) -> Dict[str, float]:
wait_start = time.perf_counter()
while True:
with self.lock:
if self.max_inflight <= 0 or self.inflight < self.max_inflight:
self.inflight += 1
self.total_entered += 1
wait_ms = max(0.0, (time.perf_counter() - wait_start) * 1000.0)
self.total_wait_ms += float(wait_ms)
self.wait_peak_ms = max(self.wait_peak_ms, float(wait_ms))
self.peak_inflight = max(self.peak_inflight, self.inflight)
return {
"wait_ms": float(wait_ms),
"inflight": float(self.inflight),
"peak_inflight": float(self.peak_inflight),
"max_inflight": float(self.max_inflight),
}
await asyncio.sleep(self.poll_s)
def release(self) -> None:
with self.lock:
self.inflight = max(0, self.inflight - 1)
def snapshot(self) -> Dict[str, float]:
with self.lock:
return {
"max_inflight": float(self.max_inflight),
"inflight": float(self.inflight),
"peak_inflight": float(self.peak_inflight),
"total_entered": float(self.total_entered),
"total_wait_ms": float(self.total_wait_ms),
"wait_peak_ms": float(self.wait_peak_ms),
}
class PrepareCoordinator:
def __init__(self, tts: Any):
self.tts = tts
@ -59,7 +106,8 @@ class PrepareCoordinator:
and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0"
)
self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0")))
self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None
gate_poll_ms = int(os.environ.get("GPTSOVITS_PREPARE_GATE_POLL_MS", "1"))
self._inflight_gate = AsyncStageGate(self.max_inflight, poll_ms=gate_poll_ms)
self.text_feature_workers = 0
self.text_feature_executor = None
if not self.use_async_text_feature_path:
@ -81,6 +129,29 @@ class PrepareCoordinator:
max_workers=self.ref_audio_workers,
thread_name_prefix="prepare-ref-audio",
)
text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0))
text_feature_gate_default = max(0, int(self.text_feature_workers))
ref_audio_gate_default = max(0, int(self.ref_audio_workers))
self.text_cpu_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))),
poll_ms=gate_poll_ms,
)
self.text_feature_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))),
poll_ms=gate_poll_ms,
)
self.ref_audio_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_REF_MAX_INFLIGHT", str(ref_audio_gate_default))),
poll_ms=gate_poll_ms,
)
self.ref_load_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_REF_LOAD_MAX_INFLIGHT", str(ref_audio_gate_default))),
poll_ms=gate_poll_ms,
)
self.ref_spec_gate = AsyncStageGate(
int(os.environ.get("GPTSOVITS_PREPARE_REF_SPEC_MAX_INFLIGHT", str(ref_audio_gate_default))),
poll_ms=gate_poll_ms,
)
def _mark_enter(self) -> Tuple[int, int]:
with self.lock:
@ -94,15 +165,29 @@ class PrepareCoordinator:
with self.lock:
self.inflight = max(0, self.inflight - 1)
def snapshot(self) -> Dict[str, int]:
def snapshot(self) -> Dict[str, Any]:
with self.lock:
return {
snapshot: Dict[str, Any] = {
"inflight": int(self.inflight),
"peak_inflight": int(self.peak_inflight),
"max_inflight": int(self.max_inflight),
"text_feature_workers": int(self.text_feature_workers),
"ref_audio_workers": int(self.ref_audio_workers),
}
runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None)
if callable(runtime_snapshot_fn):
try:
snapshot["prepare_runtime_state"] = runtime_snapshot_fn()
except Exception:
snapshot["prepare_runtime_state"] = None
snapshot["prepare_stage_gates"] = {
"text_cpu": self.text_cpu_gate.snapshot(),
"text_feature": self.text_feature_gate.snapshot(),
"ref_audio": self.ref_audio_gate.snapshot(),
"ref_load": self.ref_load_gate.snapshot(),
"ref_spec": self.ref_spec_gate.snapshot(),
}
return snapshot
@staticmethod
def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult:
@ -119,6 +204,12 @@ class PrepareCoordinator:
def _prepare_text_cpu(self, text: str, language: str):
return self.tts.prepare_text_segments(text, language)
def _load_ref_audio_raw(self, ref_audio_path: str):
return self.tts._load_ref_audio_raw(ref_audio_path)
def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int):
return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
@staticmethod
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
feature_dim = 1024
@ -155,17 +246,54 @@ class PrepareCoordinator:
return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args)
async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult:
await self.text_cpu_gate.acquire()
if text in [None, ""]:
submit_at = time.perf_counter()
return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at)
try:
submit_at = time.perf_counter()
return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at)
finally:
self.text_cpu_gate.release()
text_cpu_worker = getattr(self.tts, "prepare_text_cpu_worker", None)
executor = getattr(self.tts, "prepare_text_cpu_executor", None)
if executor is None:
submit_at = time.perf_counter()
return self._run_profiled(self._prepare_text_cpu, submit_at, text, language)
return await self._run_on_executor(executor, self._prepare_text_cpu, text, language)
try:
if text_cpu_worker is not None:
submit_at = time.perf_counter()
result, worker_profile = await text_cpu_worker.submit_async(text, language)
started_at = float(
submit_at
+ (
float(worker_profile.get("text_cpu_admission_wait_ms", 0.0))
+ float(worker_profile.get("text_cpu_queue_wait_ms", 0.0))
)
/ 1000.0
)
finished_at = float(started_at + float(worker_profile.get("text_cpu_run_ms", 0.0)) / 1000.0)
return ProfiledResult(
result=result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=finished_at,
profile=dict(worker_profile),
)
if executor is None:
submit_at = time.perf_counter()
return self._run_profiled(self._prepare_text_cpu, submit_at, text, language)
return await self._run_on_executor(executor, self._prepare_text_cpu, text, language)
finally:
self.text_cpu_gate.release()
async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult:
return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms)
await self.text_feature_gate.acquire()
try:
return await self._run_on_executor(
self.text_feature_executor,
self._build_text_features,
prepared_segments,
language,
cpu_run_ms,
)
finally:
self.text_feature_gate.release()
@staticmethod
def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
@ -199,16 +327,34 @@ class PrepareCoordinator:
)
return prompt_profiled, target_profiled
await self.text_feature_gate.acquire()
target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
submit_at = time.perf_counter()
started_at = float(submit_at)
if prompt_is_empty:
target_result_raw = await self.tts.build_text_features_from_segments_async(
target_segments,
profile=target_profile,
)
prompt_result = self._build_empty_text_features_like(
PreparedTextFeatures(
try:
if prompt_is_empty:
target_result_raw = await self.tts.build_text_features_from_segments_async(
target_segments,
profile=target_profile,
)
prompt_result = self._build_empty_text_features_like(
PreparedTextFeatures(
phones=target_result_raw[0],
bert_features=target_result_raw[1],
norm_text=target_result_raw[2],
profile=target_profile,
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
cpu_preprocess_ms=float(target_cpu_run_ms),
)
)
finished_at = time.perf_counter()
prompt_profiled = ProfiledResult(
result=prompt_result,
submit_at=float(submit_at),
started_at=float(submit_at),
finished_at=float(submit_at),
)
target_result = PreparedTextFeatures(
phones=target_result_raw[0],
bert_features=target_result_raw[1],
norm_text=target_result_raw[2],
@ -216,13 +362,37 @@ class PrepareCoordinator:
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
cpu_preprocess_ms=float(target_cpu_run_ms),
)
target_profiled = ProfiledResult(
result=target_result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
)
if finished_at > target_profiled.finished_at:
target_result.profile["bert_total_ms"] = max(
self._estimate_text_feature_run_ms(target_profile),
(finished_at - submit_at) * 1000.0,
)
else:
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
return prompt_profiled, target_profiled
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
prompt_segments,
target_segments,
prompt_profile=prompt_profile,
target_profile=target_profile,
)
finished_at = time.perf_counter()
prompt_profiled = ProfiledResult(
result=prompt_result,
submit_at=float(submit_at),
started_at=float(submit_at),
finished_at=float(submit_at),
prompt_result = PreparedTextFeatures(
phones=prompt_result_raw[0],
bert_features=prompt_result_raw[1],
norm_text=prompt_result_raw[2],
profile=prompt_profile,
total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)),
cpu_preprocess_ms=float(prompt_cpu_run_ms),
)
target_result = PreparedTextFeatures(
phones=target_result_raw[0],
@ -232,79 +402,152 @@ class PrepareCoordinator:
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
cpu_preprocess_ms=float(target_cpu_run_ms),
)
prompt_profiled = ProfiledResult(
result=prompt_result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0),
)
target_profiled = ProfiledResult(
result=target_result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
)
if finished_at > target_profiled.finished_at:
if finished_at > prompt_profiled.finished_at:
prompt_result.profile["bert_total_ms"] = max(
self._estimate_text_feature_run_ms(prompt_profile),
(finished_at - submit_at) * 1000.0,
)
target_result.profile["bert_total_ms"] = max(
self._estimate_text_feature_run_ms(target_profile),
(finished_at - submit_at) * 1000.0,
)
else:
prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile)
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
return prompt_profiled, target_profiled
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
prompt_segments,
target_segments,
prompt_profile=prompt_profile,
target_profile=target_profile,
)
finished_at = time.perf_counter()
prompt_result = PreparedTextFeatures(
phones=prompt_result_raw[0],
bert_features=prompt_result_raw[1],
norm_text=prompt_result_raw[2],
profile=prompt_profile,
total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)),
cpu_preprocess_ms=float(prompt_cpu_run_ms),
)
target_result = PreparedTextFeatures(
phones=target_result_raw[0],
bert_features=target_result_raw[1],
norm_text=target_result_raw[2],
profile=target_profile,
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
cpu_preprocess_ms=float(target_cpu_run_ms),
)
prompt_profiled = ProfiledResult(
result=prompt_result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0),
)
target_profiled = ProfiledResult(
result=target_result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
)
if finished_at > prompt_profiled.finished_at:
prompt_result.profile["bert_total_ms"] = max(
self._estimate_text_feature_run_ms(prompt_profile),
(finished_at - submit_at) * 1000.0,
)
target_result.profile["bert_total_ms"] = max(
self._estimate_text_feature_run_ms(target_profile),
(finished_at - submit_at) * 1000.0,
)
else:
prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile)
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
return prompt_profiled, target_profiled
finally:
self.text_feature_gate.release()
async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
submit_at = time.perf_counter()
started_at = float(submit_at)
await self.ref_load_gate.acquire()
try:
load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path)
finally:
self.ref_load_gate.release()
raw_audio, raw_sr = load_profiled.result
prompt_semantic_task = asyncio.create_task(
self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
)
await self.ref_spec_gate.acquire()
try:
ref_spec_task = asyncio.create_task(
self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
)
(prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather(
prompt_semantic_task,
ref_spec_task,
)
finally:
self.ref_spec_gate.release()
refer_spec = ref_spec_profiled.result
limiter_snapshot = (
self.tts.prepare_ref_audio_stage_limiter.snapshot()
if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None
else {}
)
prompt_semantic_ms = (
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
)
audio_stage_wait_ms = (
float(load_profiled.queue_ms)
+ float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0))
+ float(ref_spec_profiled.queue_ms)
)
finished_at = time.perf_counter()
result = {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_queue_ms": float(load_profiled.queue_ms),
"audio_load_ms": float(load_profiled.run_ms),
"audio_stage_wait_ms": float(audio_stage_wait_ms),
"audio_stage_slots": float(
max(
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
float(limiter_snapshot.get("slots", 0.0)),
)
),
"audio_stage_inflight_peak": float(
max(
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
float(limiter_snapshot.get("peak_inflight", 0.0)),
)
),
"prompt_semantic_ms": float(prompt_semantic_ms),
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
),
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
"prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
),
"ref_spec_wait_ms": float(ref_spec_profiled.queue_ms),
"ref_spec_ms": float(ref_spec_profiled.run_ms),
"bundle_total_ms": float(
load_profiled.queue_ms
+ load_profiled.run_ms
+ prompt_semantic_ms
+ ref_spec_profiled.queue_ms
+ ref_spec_profiled.run_ms
),
},
}
return ProfiledResult(
result=result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(finished_at),
)
await self.ref_audio_gate.acquire()
try:
if hasattr(self.tts, "extract_ref_audio_bundle_async"):
submit_at = time.perf_counter()
started_at = time.perf_counter()
result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path)
finished_at = time.perf_counter()
return ProfiledResult(
result=result,
submit_at=float(submit_at),
started_at=float(started_at),
finished_at=float(finished_at),
)
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
finally:
self.ref_audio_gate.release()
def _release_split_stage_slot(self) -> None:
self._mark_leave()
if self._inflight_semaphore is not None:
self._inflight_semaphore.release()
self._inflight_gate.release()
async def prepare_cpu_stage_profiled_async(
self,
@ -312,9 +555,11 @@ class PrepareCoordinator:
prepare_submit_at: float,
) -> PreparedCpuStage:
admission_start = time.perf_counter()
if self._inflight_semaphore is not None:
await self._inflight_semaphore.acquire()
prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0)
admission_stats = await self._inflight_gate.acquire()
prepare_admission_wait_ms = max(
float(admission_stats.get("wait_ms", 0.0)),
(time.perf_counter() - admission_start) * 1000.0,
)
current_inflight, peak_inflight = self._mark_enter()
prepare_start = time.perf_counter()
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
@ -382,10 +627,28 @@ class PrepareCoordinator:
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
"prompt_text_cpu_admission_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"prompt_text_cpu_backpressure_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"prompt_text_cpu_capacity_wait_ms": float(
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
"text_cpu_admission_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
),
"text_cpu_backpressure_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
),
"text_cpu_capacity_wait_ms": float(
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
),
"text_feature_queue_ms": target_feature_profiled.queue_ms,
"text_feature_run_ms": target_feature_profiled.run_ms,
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,

View File

@ -1,3 +1,4 @@
import asyncio
import threading
import time
import uuid
@ -51,6 +52,8 @@ class RefSemanticTask:
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
result_prompt_semantic: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
@ -115,6 +118,41 @@ class PrepareRefSemanticBatchWorker:
assert task.result_prompt_semantic is not None
return task.result_prompt_semantic, dict(task.profile)
async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
loop = asyncio.get_running_loop()
task = RefSemanticTask(
raw_audio=raw_audio,
raw_sr=int(raw_sr),
done_loop=loop,
done_future=loop.create_future(),
)
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
return await task.done_future
@staticmethod
def _resolve_done_future(task: RefSemanticTask) -> None:
if task.done_future is None or task.done_future.done():
return
if task.error is not None:
task.done_future.set_exception(task.error)
return
assert task.result_prompt_semantic is not None
task.done_future.set_result((task.result_prompt_semantic, dict(task.profile)))
def _notify_task_done(self, task: RefSemanticTask) -> None:
task.done_event.set()
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
except RuntimeError:
pass
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
@ -247,7 +285,7 @@ class PrepareRefSemanticBatchWorker:
for task in batch:
if task.result_prompt_semantic is not None:
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
task.done_event.set()
self._notify_task_done(task)
def _run_loop(self) -> None:
while True:
@ -257,6 +295,6 @@ class PrepareRefSemanticBatchWorker:
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
self._notify_task_done(task)
finally:
self._finalize_batch(batch)

View File

@ -0,0 +1,215 @@
import asyncio
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Deque, Dict, Tuple
@dataclass
class TextCpuTask:
text: str
language: str
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
enqueued_at: float = 0.0
admission_wait_ms: float = 0.0
backpressure_wait_ms: float = 0.0
capacity_wait_ms: float = 0.0
pending_depth_on_enqueue: int = 0
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
result: Any = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareTextCpuWorker:
def __init__(
self,
process_fn: Callable[[str, str], Any],
worker_count: int,
max_pending_tasks: int = 0,
admission_poll_ms: int = 1,
admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None,
) -> None:
self.process_fn = process_fn
self.worker_count = max(1, int(worker_count))
self.max_pending_tasks = max(0, int(max_pending_tasks))
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
self.admission_controller = admission_controller
self.condition = threading.Condition()
self.pending_tasks: Deque[TextCpuTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.active_workers = 0
self.active_workers_peak = 0
self.admission_wait_total_ms = 0.0
self.admission_wait_peak_ms = 0.0
self.backpressure_wait_total_ms = 0.0
self.backpressure_wait_peak_ms = 0.0
self.capacity_wait_total_ms = 0.0
self.capacity_wait_peak_ms = 0.0
self.backpressure_blocked_total = 0
self.worker_threads = [
threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True)
for index in range(self.worker_count)
]
for thread in self.worker_threads:
thread.start()
def _can_enqueue_locked(self) -> bool:
if self.max_pending_tasks <= 0:
return True
return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks
def _get_admission_state(self) -> Dict[str, float | int | bool]:
if self.admission_controller is None:
return {"blocked": False}
try:
state = dict(self.admission_controller() or {})
except Exception:
return {"blocked": False}
state["blocked"] = bool(state.get("blocked", False))
return state
def _record_enqueue_locked(
self,
task: TextCpuTask,
*,
admission_wait_ms: float,
backpressure_wait_ms: float,
capacity_wait_ms: float,
) -> None:
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms))
task.capacity_wait_ms = float(max(0.0, capacity_wait_ms))
task.enqueued_at = time.perf_counter()
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
self.pending_tasks.append(task)
self.total_submitted += 1
self.admission_wait_total_ms += task.admission_wait_ms
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
self.backpressure_wait_total_ms += task.backpressure_wait_ms
self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms)
self.capacity_wait_total_ms += task.capacity_wait_ms
self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms)
if task.backpressure_wait_ms > 0.0:
self.backpressure_blocked_total += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
async def _enqueue_task_async(self, task: TextCpuTask) -> None:
admission_started = time.perf_counter()
backpressure_wait_ms = 0.0
capacity_wait_ms = 0.0
while True:
loop_start = time.perf_counter()
admission_state = self._get_admission_state()
blocked = bool(admission_state.get("blocked", False))
with self.condition:
if not blocked and self._can_enqueue_locked():
self._record_enqueue_locked(
task,
admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0,
backpressure_wait_ms=backpressure_wait_ms,
capacity_wait_ms=capacity_wait_ms,
)
return
await asyncio.sleep(self.admission_poll_s)
waited_ms = (time.perf_counter() - loop_start) * 1000.0
if blocked:
backpressure_wait_ms += waited_ms
else:
capacity_wait_ms += waited_ms
def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
task = TextCpuTask(text=str(text), language=str(language))
asyncio.run(self._enqueue_task_async(task))
task.done_event.wait()
if task.error is not None:
raise task.error
return task.result, dict(task.profile)
async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
loop = asyncio.get_running_loop()
task = TextCpuTask(
text=str(text),
language=str(language),
done_loop=loop,
done_future=loop.create_future(),
)
await self._enqueue_task_async(task)
return await task.done_future
@staticmethod
def _resolve_done_future(task: TextCpuTask) -> None:
if task.done_future is None or task.done_future.done():
return
if task.error is not None:
task.done_future.set_exception(task.error)
return
task.done_future.set_result((task.result, dict(task.profile)))
def _notify_task_done(self, task: TextCpuTask) -> None:
task.done_event.set()
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
except RuntimeError:
pass
def snapshot(self) -> Dict[str, int | float]:
with self.condition:
return {
"worker_count": int(self.worker_count),
"pending": int(len(self.pending_tasks)),
"pending_peak": int(self.pending_peak),
"active_workers": int(self.active_workers),
"active_workers_peak": int(self.active_workers_peak),
"total_submitted": int(self.total_submitted),
"total_finished": int(self.total_finished),
"max_pending_tasks": int(self.max_pending_tasks),
"admission_wait_total_ms": float(self.admission_wait_total_ms),
"admission_wait_peak_ms": float(self.admission_wait_peak_ms),
"backpressure_wait_total_ms": float(self.backpressure_wait_total_ms),
"backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms),
"capacity_wait_total_ms": float(self.capacity_wait_total_ms),
"capacity_wait_peak_ms": float(self.capacity_wait_peak_ms),
"backpressure_blocked_total": int(self.backpressure_blocked_total),
}
def _run_loop(self) -> None:
while True:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
task = self.pending_tasks.popleft()
self.active_workers += 1
self.active_workers_peak = max(self.active_workers_peak, self.active_workers)
started_at = time.perf_counter()
try:
task.result = self.process_fn(task.text, task.language)
task.profile = {
"text_cpu_admission_wait_ms": float(task.admission_wait_ms),
"text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms),
"text_cpu_capacity_wait_ms": float(task.capacity_wait_ms),
"text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0),
"text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
"text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0),
}
except Exception as exc: # noqa: PERF203
task.error = exc
finally:
with self.condition:
self.active_workers = max(0, self.active_workers - 1)
self.total_finished += 1
self.condition.notify_all()
self._notify_task_done(task)

View File

@ -421,6 +421,55 @@ def _iter_contiguous_sampling_groups(
return groups
def _uniform_sampling_group_key(active_batch: T2SActiveBatch) -> Optional[Tuple[int, float, float, float, bool]]:
if not active_batch.states:
return None
if active_batch.step_indices.numel() <= 0:
return None
first_step_index = int(active_batch.step_indices[0].item())
if bool((active_batch.step_indices != first_step_index).any().item()):
return None
first_state = active_batch.states[0]
first_key = _sampling_group_key(
top_k=first_state.top_k,
top_p=first_state.top_p,
temperature=first_state.temperature,
repetition_penalty=first_state.repetition_penalty,
trim_eos=first_step_index < 11,
)
for state in active_batch.states[1:]:
if (
state.top_k != first_state.top_k
or state.top_p != first_state.top_p
or state.temperature != first_state.temperature
or state.repetition_penalty != first_state.repetition_penalty
):
return None
return first_key
def _batched_sample_uniform(
logits: torch.Tensor,
histories: Sequence[torch.LongTensor],
sampling_key: Tuple[int, float, float, float, bool],
) -> Tuple[torch.Tensor, torch.Tensor]:
top_k, top_p, temperature, repetition_penalty, trim_eos = sampling_key
sample_logits = logits[:, :-1] if trim_eos else logits
padded_histories, history_mask = _pad_token_sequences(histories)
probs = logits_to_probs(
logits=sample_logits,
previous_tokens=padded_histories,
previous_token_mask=history_mask,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
)
sampled = multinomial_sample_one_no_sync(probs)
argmax_tokens = torch.argmax(sample_logits, dim=-1)
return sampled, argmax_tokens
def _batched_sample_by_group(
logits: torch.Tensor,
histories: Sequence[torch.LongTensor],
@ -594,27 +643,59 @@ def _sample_per_request(
keep_indices: List[int] = []
updated_sequences: List[torch.LongTensor] = []
sampling_keys = [
_sampling_group_key(
top_k=state.top_k,
top_p=state.top_p,
temperature=state.temperature,
repetition_penalty=state.repetition_penalty,
trim_eos=int(active_batch.step_indices[batch_index].item()) < 11,
uniform_sampling_key = _uniform_sampling_group_key(active_batch)
sampled_items: List[torch.Tensor]
argmax_tokens: List[int]
sampled_token_tensor: Optional[torch.Tensor] = None
argmax_token_tensor: Optional[torch.Tensor] = None
if uniform_sampling_key is not None:
sampled_tensor, argmax_tensor = _batched_sample_uniform(
logits=logits,
histories=active_batch.y_sequences,
sampling_key=uniform_sampling_key,
)
sampled_token_tensor = sampled_tensor.view(-1)
argmax_token_tensor = argmax_tensor.view(-1)
if (
all(state.early_stop_num == -1 for state in active_batch.states)
and int(active_batch.step_indices[0].item()) + 1 < max_steps
and not bool(sampled_token_tensor.eq(model.EOS).any().item())
and not bool(argmax_token_tensor.eq(model.EOS).any().item())
):
return (
[],
list(range(len(active_batch.states))),
[torch.cat([history, sampled_token_tensor[index : index + 1]], dim=0) for index, history in enumerate(active_batch.y_sequences)],
)
sampled_items = [sampled_tensor[index : index + 1] for index in range(sampled_tensor.shape[0])]
argmax_tokens = [int(item) for item in argmax_tensor.tolist()]
else:
sampling_keys = [
_sampling_group_key(
top_k=state.top_k,
top_p=state.top_p,
temperature=state.temperature,
repetition_penalty=state.repetition_penalty,
trim_eos=int(active_batch.step_indices[batch_index].item()) < 11,
)
for batch_index, state in enumerate(active_batch.states)
]
sampled_items, argmax_tokens = _batched_sample_by_group(
logits=logits,
histories=active_batch.y_sequences,
sampling_keys=sampling_keys,
)
for batch_index, state in enumerate(active_batch.states)
]
sampled_items, argmax_tokens = _batched_sample_by_group(
logits=logits,
histories=active_batch.y_sequences,
sampling_keys=sampling_keys,
)
for batch_index, state in enumerate(active_batch.states):
step_index = int(active_batch.step_indices[batch_index].item())
current_history = active_batch.y_sequences[batch_index]
sampled = sampled_items[batch_index]
sampled_token = int(sampled[0, 0].item())
argmax_token = argmax_tokens[batch_index]
if sampled_token_tensor is not None and argmax_token_tensor is not None:
sampled = sampled_token_tensor[batch_index : batch_index + 1]
sampled_token = int(sampled_token_tensor[batch_index].item())
argmax_token = int(argmax_token_tensor[batch_index].item())
else:
sampled = sampled_items[batch_index]
sampled_token = int(sampled[0, 0].item())
argmax_token = argmax_tokens[batch_index]
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
finish_reason: Optional[str] = None
@ -690,6 +771,13 @@ def decode_one_step(
finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps)
if len(keep_indices) == 0:
return None, finished_items
if len(keep_indices) == len(active_batch.request_ids):
active_batch.y_sequences = updated_sequences
active_batch.step_indices = active_batch.step_indices + 1
if not was_prefill and active_batch.kv_lens is not None:
active_batch.kv_lens = active_batch.kv_lens + 1
active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences)
return active_batch, finished_items
device = logits.device
keep_tensor = torch.LongTensor(keep_indices).to(device)

View File

@ -314,6 +314,7 @@ def build_scheduler_submit_headers(
"X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)),
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
"X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)),
"X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))),
"X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)),
"X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)),
"X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)),

View File

@ -44,6 +44,16 @@ class EngineTaskQueueOwner:
return None
return self.queue.popleft()
def pop_left_many(self, max_items: int) -> List[Any]:
limit = max(1, int(max_items))
with self.condition:
if not self.queue:
return []
selected: List[Any] = []
while self.queue and len(selected) < limit:
selected.append(self.queue.popleft())
return selected
def mark_completed(self, count: int = 1, *, notify: bool = False) -> None:
if count <= 0:
return
@ -315,6 +325,7 @@ class EngineGpuPrepareTask:
engine_request_id: str | None
enqueue_time: float
queue_wait_ms: float = 0.0
admission_wait_ms: float = 0.0
error: str | None = None

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import os
import time
from typing import Any
@ -9,6 +10,19 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare
class EnginePrepareStageMixin:
async def _wait_prepare_queue_admission(self) -> float:
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
if soft_max <= 0:
return 0.0
poll_s = max(
0.0005,
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
)
wait_start = time.perf_counter()
while self.prepare_queue_owner.waiting_count() >= soft_max:
await asyncio.sleep(poll_s)
return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
async def prepare_state_via_engine_gpu_queue(
self,
*,
@ -16,12 +30,14 @@ class EnginePrepareStageMixin:
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission()
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
if engine_request_id not in [None, ""]:
self.update_request_state(
str(engine_request_id),
EngineStatus.GPU_PREPARING,
{
"engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms),
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
@ -37,31 +53,44 @@ class EnginePrepareStageMixin:
done_future=done_future,
engine_request_id=engine_request_id or spec.request_id,
enqueue_time=time.perf_counter(),
admission_wait_ms=float(prepare_queue_admission_wait_ms),
)
self.prepare_queue_owner.enqueue(task)
self.notify_arbiter()
return await done_future
def run_engine_prepare_once(self) -> bool:
task = self.prepare_queue_owner.pop_left()
if task is None:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
if not tasks:
return False
queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
try:
state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
)
now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
batch_results = asyncio.run(
self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks])
)
completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
if isinstance(result, Exception):
task.error = str(result)
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result)
completed_count += 1
continue
state, prepare_exec_started_at, prepare_exec_finished_at = result
state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms)
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
{
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms),
"engine_gpu_prepare_batch_size": float(len(tasks)),
},
)
self.prepare_queue_owner.mark_completed(1)
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
return True
except Exception as exc:
task.error = str(exc)
self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
self._notify_prepare_error(task, exc)
return True
completed_count += 1
self.prepare_queue_owner.mark_completed(completed_count)
return True

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import os
import time
from typing import Callable, Dict, List
@ -32,6 +33,11 @@ class WorkerPrepareExecutor:
def get_max_inflight(self) -> int:
return int(self.coordinator.snapshot().get("max_inflight", 0))
def get_batch_policy(self) -> Dict[str, int]:
return {
"prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))),
}
def is_idle(self) -> bool:
return int(self.coordinator.snapshot().get("inflight", 0)) <= 0
@ -69,3 +75,17 @@ class WorkerPrepareExecutor:
return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage)
finally:
self._notify_state_change()
async def prepare_gpu_stages_profiled_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]:
try:
return list(
await asyncio.gather(
*[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages],
return_exceptions=True,
)
)
finally:
self._notify_state_change()

View File

@ -32,6 +32,9 @@ class WorkerSubmitLifecycleMixin:
def get_finalize_batch_policy(self) -> Dict[str, Any]:
return dict(self.finalize_executor.get_batch_policy())
def get_prepare_batch_policy(self) -> Dict[str, int]:
return dict(self.prepare_executor.get_batch_policy())
def get_decode_runtime_counters(self) -> Dict[str, int]:
with self.condition:
return self.decode_runtime_tracker.get_counters()
@ -258,3 +261,9 @@ class WorkerSubmitLifecycleMixin:
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)
async def prepare_gpu_stages_profiled_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]:
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)

1
third_party/g2pw-cu vendored Submodule

@ -0,0 +1 @@
Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98