mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-18 08:18:16 +08:00
Add g2pw submodule and enhance TTS processing with AsyncStageGate
Introduce a new submodule for g2pw and implement AsyncStageGate in PrepareCoordinator to manage concurrent task inflight limits. Update PrepareTextCpuWorker and PrepareRefSemanticBatchWorker to support asynchronous task submission and completion notifications. Enhance profiling capabilities in TTS to track g2pw processing times, improving overall performance and maintainability of the TTS system.
This commit is contained in:
parent
6a822b28c3
commit
5cf68a91d3
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "third_party/g2pw-cu"]
|
||||
path = third_party/g2pw-cu
|
||||
url = https://github.com/baicai-1145/g2pw-cu.git
|
||||
@ -1,4 +1,5 @@
|
||||
import gc
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import math
|
||||
import os
|
||||
@ -42,6 +43,7 @@ from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
|
||||
PrepareRefSemanticBatchWorker,
|
||||
prepare_prompt_semantic_wav16k,
|
||||
)
|
||||
from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker
|
||||
from sv import SV
|
||||
|
||||
resample_transform_dict = {}
|
||||
@ -454,18 +456,12 @@ class TTS:
|
||||
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
|
||||
self.prepare_bert_batch_worker = None
|
||||
self.prepare_ref_semantic_batch_worker = None
|
||||
self.prepare_text_cpu_worker = None
|
||||
self.prepare_text_cpu_workers = max(
|
||||
0,
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")),
|
||||
)
|
||||
self.prepare_text_cpu_executor = (
|
||||
concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=self.prepare_text_cpu_workers,
|
||||
thread_name_prefix="prepare-text-cpu",
|
||||
)
|
||||
if self.prepare_text_cpu_workers > 0
|
||||
else None
|
||||
)
|
||||
self.prepare_text_cpu_executor = None
|
||||
|
||||
self._init_models()
|
||||
self.refresh_runtime_components()
|
||||
@ -488,6 +484,7 @@ class TTS:
|
||||
def refresh_runtime_components(self):
|
||||
self.prepare_bert_batch_worker = None
|
||||
self.prepare_ref_semantic_batch_worker = None
|
||||
self.prepare_text_cpu_worker = None
|
||||
if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
|
||||
self.prepare_bert_batch_worker = PrepareBertBatchWorker(
|
||||
bert_model=self.bert_model,
|
||||
@ -535,6 +532,92 @@ class TTS:
|
||||
bert_stage_limiter=self.prepare_bert_stage_limiter,
|
||||
bert_batch_worker=self.prepare_bert_batch_worker,
|
||||
)
|
||||
if self.prepare_text_cpu_workers > 0:
|
||||
self.prepare_text_cpu_worker = PrepareTextCpuWorker(
|
||||
process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments(
|
||||
text,
|
||||
language,
|
||||
self.configs.version,
|
||||
),
|
||||
worker_count=self.prepare_text_cpu_workers,
|
||||
max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")),
|
||||
admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")),
|
||||
admission_controller=self._build_text_cpu_admission_state,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _safe_queue_qsize(executor) -> int | None:
|
||||
if executor is None:
|
||||
return None
|
||||
queue = getattr(executor, "_work_queue", None)
|
||||
if queue is None or not hasattr(queue, "qsize"):
|
||||
return None
|
||||
try:
|
||||
return int(queue.qsize())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def snapshot_prepare_runtime_components(self) -> dict:
|
||||
return {
|
||||
"text_cpu": {
|
||||
"workers": int(self.prepare_text_cpu_workers),
|
||||
"queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor),
|
||||
"enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None),
|
||||
"worker": (
|
||||
None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot())
|
||||
),
|
||||
"admission": self._build_text_cpu_admission_state(),
|
||||
},
|
||||
"bert": {
|
||||
"stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()),
|
||||
"batch_worker": (
|
||||
None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot())
|
||||
),
|
||||
"batching_enabled": bool(self.prepare_bert_batch_worker is not None),
|
||||
},
|
||||
"ref_semantic": {
|
||||
"stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()),
|
||||
"batch_worker": (
|
||||
None
|
||||
if self.prepare_ref_semantic_batch_worker is None
|
||||
else dict(self.prepare_ref_semantic_batch_worker.snapshot())
|
||||
),
|
||||
"batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None),
|
||||
},
|
||||
"text_preprocessor": (
|
||||
None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot()
|
||||
),
|
||||
}
|
||||
|
||||
def _build_text_cpu_admission_state(self) -> dict:
|
||||
bert_pending_soft_max = max(
|
||||
0,
|
||||
int(
|
||||
os.environ.get(
|
||||
"GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX",
|
||||
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"),
|
||||
)
|
||||
),
|
||||
)
|
||||
if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0:
|
||||
return {
|
||||
"blocked": False,
|
||||
"reason": "",
|
||||
"bert_pending": 0,
|
||||
"bert_active_batch_size": 0,
|
||||
"bert_pending_soft_max": int(bert_pending_soft_max),
|
||||
}
|
||||
bert_state = dict(self.prepare_bert_batch_worker.snapshot())
|
||||
bert_pending = int(bert_state.get("pending", 0))
|
||||
bert_active_batch_size = int(bert_state.get("active_batch_size", 0))
|
||||
blocked = bert_pending >= bert_pending_soft_max
|
||||
return {
|
||||
"blocked": bool(blocked),
|
||||
"reason": ("bert_pending" if blocked else ""),
|
||||
"bert_pending": int(bert_pending),
|
||||
"bert_active_batch_size": int(bert_active_batch_size),
|
||||
"bert_pending_soft_max": int(bert_pending_soft_max),
|
||||
}
|
||||
|
||||
def _init_models(
|
||||
self,
|
||||
@ -1040,6 +1123,79 @@ class TTS:
|
||||
},
|
||||
}
|
||||
|
||||
async def extract_ref_audio_bundle_async(self, ref_audio_path: str):
|
||||
if self.prepare_ref_semantic_batch_worker is None:
|
||||
return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path)
|
||||
|
||||
load_start = time.perf_counter()
|
||||
raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path)
|
||||
load_ms = (time.perf_counter() - load_start) * 1000.0
|
||||
|
||||
prompt_semantic_task = asyncio.create_task(
|
||||
self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
|
||||
)
|
||||
|
||||
def _build_ref_spec_profile():
|
||||
with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
|
||||
ref_spec_start = time.perf_counter()
|
||||
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
|
||||
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
|
||||
return refer_spec, {
|
||||
"ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
|
||||
"ref_spec_ms": float(ref_spec_ms),
|
||||
"audio_stage_slots": float(ref_spec_limiter_stats["slots"]),
|
||||
"audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]),
|
||||
}
|
||||
|
||||
ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile))
|
||||
(prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather(
|
||||
prompt_semantic_task,
|
||||
ref_spec_task,
|
||||
)
|
||||
|
||||
prompt_semantic_ms = (
|
||||
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
|
||||
)
|
||||
audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
|
||||
ref_spec_profile.get("ref_spec_wait_ms", 0.0)
|
||||
)
|
||||
audio_stage_slots = max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
float(ref_spec_profile.get("audio_stage_slots", 0.0)),
|
||||
)
|
||||
audio_stage_inflight_peak = max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
|
||||
float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)),
|
||||
)
|
||||
return {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"refer_spec": refer_spec,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
"audio_load_ms": float(load_ms),
|
||||
"audio_stage_wait_ms": float(audio_stage_wait_ms),
|
||||
"audio_stage_slots": float(audio_stage_slots),
|
||||
"audio_stage_inflight_peak": float(audio_stage_inflight_peak),
|
||||
"prompt_semantic_ms": float(prompt_semantic_ms),
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
"prompt_semantic_stage_inflight_peak": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
|
||||
"prompt_semantic_batch_samples": float(prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)),
|
||||
"ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)),
|
||||
"ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)),
|
||||
"bundle_total_ms": float(load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)),
|
||||
},
|
||||
}
|
||||
|
||||
def extract_text_features(self, text: str, language: str, profile: dict | None = None):
|
||||
return self.text_preprocessor.segment_and_extract_feature_for_text(
|
||||
text, language, self.configs.version, profile=profile
|
||||
|
||||
@ -118,6 +118,15 @@ class TextPreprocessor:
|
||||
self.bert_stage_limiter = bert_stage_limiter
|
||||
self.bert_batch_worker = bert_batch_worker
|
||||
|
||||
def snapshot(self) -> Dict[str, object]:
|
||||
return {
|
||||
"device": str(self.device),
|
||||
"bert_stage_limiter": (
|
||||
None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot())
|
||||
),
|
||||
"bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()),
|
||||
}
|
||||
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
text = self.replace_consecutive_punctuation(text)
|
||||
|
||||
@ -24,6 +24,7 @@ class ProfiledResult:
|
||||
submit_at: float
|
||||
started_at: float
|
||||
finished_at: float
|
||||
profile: Dict[str, float] | None = None
|
||||
|
||||
@property
|
||||
def queue_ms(self) -> float:
|
||||
@ -48,6 +49,52 @@ class PreparedCpuStage:
|
||||
target_cpu_profiled: ProfiledResult
|
||||
|
||||
|
||||
class AsyncStageGate:
|
||||
def __init__(self, max_inflight: int, poll_ms: int = 1):
|
||||
self.max_inflight = max(0, int(max_inflight))
|
||||
self.lock = threading.Lock()
|
||||
self.poll_s = max(0.0005, float(max(1, int(poll_ms))) / 1000.0)
|
||||
self.inflight = 0
|
||||
self.peak_inflight = 0
|
||||
self.total_entered = 0
|
||||
self.total_wait_ms = 0.0
|
||||
self.wait_peak_ms = 0.0
|
||||
|
||||
async def acquire(self) -> Dict[str, float]:
|
||||
wait_start = time.perf_counter()
|
||||
while True:
|
||||
with self.lock:
|
||||
if self.max_inflight <= 0 or self.inflight < self.max_inflight:
|
||||
self.inflight += 1
|
||||
self.total_entered += 1
|
||||
wait_ms = max(0.0, (time.perf_counter() - wait_start) * 1000.0)
|
||||
self.total_wait_ms += float(wait_ms)
|
||||
self.wait_peak_ms = max(self.wait_peak_ms, float(wait_ms))
|
||||
self.peak_inflight = max(self.peak_inflight, self.inflight)
|
||||
return {
|
||||
"wait_ms": float(wait_ms),
|
||||
"inflight": float(self.inflight),
|
||||
"peak_inflight": float(self.peak_inflight),
|
||||
"max_inflight": float(self.max_inflight),
|
||||
}
|
||||
await asyncio.sleep(self.poll_s)
|
||||
|
||||
def release(self) -> None:
|
||||
with self.lock:
|
||||
self.inflight = max(0, self.inflight - 1)
|
||||
|
||||
def snapshot(self) -> Dict[str, float]:
|
||||
with self.lock:
|
||||
return {
|
||||
"max_inflight": float(self.max_inflight),
|
||||
"inflight": float(self.inflight),
|
||||
"peak_inflight": float(self.peak_inflight),
|
||||
"total_entered": float(self.total_entered),
|
||||
"total_wait_ms": float(self.total_wait_ms),
|
||||
"wait_peak_ms": float(self.wait_peak_ms),
|
||||
}
|
||||
|
||||
|
||||
class PrepareCoordinator:
|
||||
def __init__(self, tts: Any):
|
||||
self.tts = tts
|
||||
@ -59,7 +106,8 @@ class PrepareCoordinator:
|
||||
and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0"
|
||||
)
|
||||
self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0")))
|
||||
self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None
|
||||
gate_poll_ms = int(os.environ.get("GPTSOVITS_PREPARE_GATE_POLL_MS", "1"))
|
||||
self._inflight_gate = AsyncStageGate(self.max_inflight, poll_ms=gate_poll_ms)
|
||||
self.text_feature_workers = 0
|
||||
self.text_feature_executor = None
|
||||
if not self.use_async_text_feature_path:
|
||||
@ -81,6 +129,29 @@ class PrepareCoordinator:
|
||||
max_workers=self.ref_audio_workers,
|
||||
thread_name_prefix="prepare-ref-audio",
|
||||
)
|
||||
text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0))
|
||||
text_feature_gate_default = max(0, int(self.text_feature_workers))
|
||||
ref_audio_gate_default = max(0, int(self.ref_audio_workers))
|
||||
self.text_cpu_gate = AsyncStageGate(
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))),
|
||||
poll_ms=gate_poll_ms,
|
||||
)
|
||||
self.text_feature_gate = AsyncStageGate(
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))),
|
||||
poll_ms=gate_poll_ms,
|
||||
)
|
||||
self.ref_audio_gate = AsyncStageGate(
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_REF_MAX_INFLIGHT", str(ref_audio_gate_default))),
|
||||
poll_ms=gate_poll_ms,
|
||||
)
|
||||
self.ref_load_gate = AsyncStageGate(
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_REF_LOAD_MAX_INFLIGHT", str(ref_audio_gate_default))),
|
||||
poll_ms=gate_poll_ms,
|
||||
)
|
||||
self.ref_spec_gate = AsyncStageGate(
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_REF_SPEC_MAX_INFLIGHT", str(ref_audio_gate_default))),
|
||||
poll_ms=gate_poll_ms,
|
||||
)
|
||||
|
||||
def _mark_enter(self) -> Tuple[int, int]:
|
||||
with self.lock:
|
||||
@ -94,15 +165,29 @@ class PrepareCoordinator:
|
||||
with self.lock:
|
||||
self.inflight = max(0, self.inflight - 1)
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
def snapshot(self) -> Dict[str, Any]:
|
||||
with self.lock:
|
||||
return {
|
||||
snapshot: Dict[str, Any] = {
|
||||
"inflight": int(self.inflight),
|
||||
"peak_inflight": int(self.peak_inflight),
|
||||
"max_inflight": int(self.max_inflight),
|
||||
"text_feature_workers": int(self.text_feature_workers),
|
||||
"ref_audio_workers": int(self.ref_audio_workers),
|
||||
}
|
||||
runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None)
|
||||
if callable(runtime_snapshot_fn):
|
||||
try:
|
||||
snapshot["prepare_runtime_state"] = runtime_snapshot_fn()
|
||||
except Exception:
|
||||
snapshot["prepare_runtime_state"] = None
|
||||
snapshot["prepare_stage_gates"] = {
|
||||
"text_cpu": self.text_cpu_gate.snapshot(),
|
||||
"text_feature": self.text_feature_gate.snapshot(),
|
||||
"ref_audio": self.ref_audio_gate.snapshot(),
|
||||
"ref_load": self.ref_load_gate.snapshot(),
|
||||
"ref_spec": self.ref_spec_gate.snapshot(),
|
||||
}
|
||||
return snapshot
|
||||
|
||||
@staticmethod
|
||||
def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult:
|
||||
@ -119,6 +204,12 @@ class PrepareCoordinator:
|
||||
def _prepare_text_cpu(self, text: str, language: str):
|
||||
return self.tts.prepare_text_segments(text, language)
|
||||
|
||||
def _load_ref_audio_raw(self, ref_audio_path: str):
|
||||
return self.tts._load_ref_audio_raw(ref_audio_path)
|
||||
|
||||
def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int):
|
||||
return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
|
||||
|
||||
@staticmethod
|
||||
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
|
||||
feature_dim = 1024
|
||||
@ -155,17 +246,54 @@ class PrepareCoordinator:
|
||||
return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args)
|
||||
|
||||
async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult:
|
||||
await self.text_cpu_gate.acquire()
|
||||
if text in [None, ""]:
|
||||
submit_at = time.perf_counter()
|
||||
return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at)
|
||||
try:
|
||||
submit_at = time.perf_counter()
|
||||
return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at)
|
||||
finally:
|
||||
self.text_cpu_gate.release()
|
||||
text_cpu_worker = getattr(self.tts, "prepare_text_cpu_worker", None)
|
||||
executor = getattr(self.tts, "prepare_text_cpu_executor", None)
|
||||
if executor is None:
|
||||
submit_at = time.perf_counter()
|
||||
return self._run_profiled(self._prepare_text_cpu, submit_at, text, language)
|
||||
return await self._run_on_executor(executor, self._prepare_text_cpu, text, language)
|
||||
try:
|
||||
if text_cpu_worker is not None:
|
||||
submit_at = time.perf_counter()
|
||||
result, worker_profile = await text_cpu_worker.submit_async(text, language)
|
||||
started_at = float(
|
||||
submit_at
|
||||
+ (
|
||||
float(worker_profile.get("text_cpu_admission_wait_ms", 0.0))
|
||||
+ float(worker_profile.get("text_cpu_queue_wait_ms", 0.0))
|
||||
)
|
||||
/ 1000.0
|
||||
)
|
||||
finished_at = float(started_at + float(worker_profile.get("text_cpu_run_ms", 0.0)) / 1000.0)
|
||||
return ProfiledResult(
|
||||
result=result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=finished_at,
|
||||
profile=dict(worker_profile),
|
||||
)
|
||||
if executor is None:
|
||||
submit_at = time.perf_counter()
|
||||
return self._run_profiled(self._prepare_text_cpu, submit_at, text, language)
|
||||
return await self._run_on_executor(executor, self._prepare_text_cpu, text, language)
|
||||
finally:
|
||||
self.text_cpu_gate.release()
|
||||
|
||||
async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult:
|
||||
return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms)
|
||||
await self.text_feature_gate.acquire()
|
||||
try:
|
||||
return await self._run_on_executor(
|
||||
self.text_feature_executor,
|
||||
self._build_text_features,
|
||||
prepared_segments,
|
||||
language,
|
||||
cpu_run_ms,
|
||||
)
|
||||
finally:
|
||||
self.text_feature_gate.release()
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
|
||||
@ -199,16 +327,34 @@ class PrepareCoordinator:
|
||||
)
|
||||
return prompt_profiled, target_profiled
|
||||
|
||||
await self.text_feature_gate.acquire()
|
||||
target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
|
||||
submit_at = time.perf_counter()
|
||||
started_at = float(submit_at)
|
||||
if prompt_is_empty:
|
||||
target_result_raw = await self.tts.build_text_features_from_segments_async(
|
||||
target_segments,
|
||||
profile=target_profile,
|
||||
)
|
||||
prompt_result = self._build_empty_text_features_like(
|
||||
PreparedTextFeatures(
|
||||
try:
|
||||
if prompt_is_empty:
|
||||
target_result_raw = await self.tts.build_text_features_from_segments_async(
|
||||
target_segments,
|
||||
profile=target_profile,
|
||||
)
|
||||
prompt_result = self._build_empty_text_features_like(
|
||||
PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
bert_features=target_result_raw[1],
|
||||
norm_text=target_result_raw[2],
|
||||
profile=target_profile,
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=prompt_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(submit_at),
|
||||
finished_at=float(submit_at),
|
||||
)
|
||||
target_result = PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
bert_features=target_result_raw[1],
|
||||
norm_text=target_result_raw[2],
|
||||
@ -216,13 +362,37 @@ class PrepareCoordinator:
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
target_profiled = ProfiledResult(
|
||||
result=target_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
|
||||
)
|
||||
if finished_at > target_profiled.finished_at:
|
||||
target_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(target_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
else:
|
||||
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
|
||||
return prompt_profiled, target_profiled
|
||||
|
||||
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
|
||||
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
prompt_profile=prompt_profile,
|
||||
target_profile=target_profile,
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=prompt_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(submit_at),
|
||||
finished_at=float(submit_at),
|
||||
|
||||
prompt_result = PreparedTextFeatures(
|
||||
phones=prompt_result_raw[0],
|
||||
bert_features=prompt_result_raw[1],
|
||||
norm_text=prompt_result_raw[2],
|
||||
profile=prompt_profile,
|
||||
total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)),
|
||||
cpu_preprocess_ms=float(prompt_cpu_run_ms),
|
||||
)
|
||||
target_result = PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
@ -232,79 +402,152 @@ class PrepareCoordinator:
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=prompt_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0),
|
||||
)
|
||||
target_profiled = ProfiledResult(
|
||||
result=target_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
|
||||
)
|
||||
if finished_at > target_profiled.finished_at:
|
||||
if finished_at > prompt_profiled.finished_at:
|
||||
prompt_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(prompt_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
target_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(target_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
else:
|
||||
prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile)
|
||||
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
|
||||
return prompt_profiled, target_profiled
|
||||
|
||||
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
|
||||
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
prompt_profile=prompt_profile,
|
||||
target_profile=target_profile,
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
prompt_result = PreparedTextFeatures(
|
||||
phones=prompt_result_raw[0],
|
||||
bert_features=prompt_result_raw[1],
|
||||
norm_text=prompt_result_raw[2],
|
||||
profile=prompt_profile,
|
||||
total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)),
|
||||
cpu_preprocess_ms=float(prompt_cpu_run_ms),
|
||||
)
|
||||
target_result = PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
bert_features=target_result_raw[1],
|
||||
norm_text=target_result_raw[2],
|
||||
profile=target_profile,
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=prompt_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0),
|
||||
)
|
||||
target_profiled = ProfiledResult(
|
||||
result=target_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
|
||||
)
|
||||
if finished_at > prompt_profiled.finished_at:
|
||||
prompt_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(prompt_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
target_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(target_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
else:
|
||||
prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile)
|
||||
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
|
||||
return prompt_profiled, target_profiled
|
||||
finally:
|
||||
self.text_feature_gate.release()
|
||||
|
||||
async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
|
||||
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
|
||||
if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None:
|
||||
submit_at = time.perf_counter()
|
||||
started_at = float(submit_at)
|
||||
|
||||
await self.ref_load_gate.acquire()
|
||||
try:
|
||||
load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path)
|
||||
finally:
|
||||
self.ref_load_gate.release()
|
||||
|
||||
raw_audio, raw_sr = load_profiled.result
|
||||
prompt_semantic_task = asyncio.create_task(
|
||||
self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
|
||||
)
|
||||
await self.ref_spec_gate.acquire()
|
||||
try:
|
||||
ref_spec_task = asyncio.create_task(
|
||||
self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr)
|
||||
)
|
||||
(prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather(
|
||||
prompt_semantic_task,
|
||||
ref_spec_task,
|
||||
)
|
||||
finally:
|
||||
self.ref_spec_gate.release()
|
||||
|
||||
refer_spec = ref_spec_profiled.result
|
||||
limiter_snapshot = (
|
||||
self.tts.prepare_ref_audio_stage_limiter.snapshot()
|
||||
if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None
|
||||
else {}
|
||||
)
|
||||
prompt_semantic_ms = (
|
||||
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
|
||||
)
|
||||
audio_stage_wait_ms = (
|
||||
float(load_profiled.queue_ms)
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0))
|
||||
+ float(ref_spec_profiled.queue_ms)
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
result = {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"refer_spec": refer_spec,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
"audio_load_queue_ms": float(load_profiled.queue_ms),
|
||||
"audio_load_ms": float(load_profiled.run_ms),
|
||||
"audio_stage_wait_ms": float(audio_stage_wait_ms),
|
||||
"audio_stage_slots": float(
|
||||
max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
float(limiter_snapshot.get("slots", 0.0)),
|
||||
)
|
||||
),
|
||||
"audio_stage_inflight_peak": float(
|
||||
max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
|
||||
float(limiter_snapshot.get("peak_inflight", 0.0)),
|
||||
)
|
||||
),
|
||||
"prompt_semantic_ms": float(prompt_semantic_ms),
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
"prompt_semantic_stage_inflight_peak": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
|
||||
"prompt_semantic_batch_samples": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
|
||||
),
|
||||
"ref_spec_wait_ms": float(ref_spec_profiled.queue_ms),
|
||||
"ref_spec_ms": float(ref_spec_profiled.run_ms),
|
||||
"bundle_total_ms": float(
|
||||
load_profiled.queue_ms
|
||||
+ load_profiled.run_ms
|
||||
+ prompt_semantic_ms
|
||||
+ ref_spec_profiled.queue_ms
|
||||
+ ref_spec_profiled.run_ms
|
||||
),
|
||||
},
|
||||
}
|
||||
return ProfiledResult(
|
||||
result=result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(finished_at),
|
||||
)
|
||||
|
||||
await self.ref_audio_gate.acquire()
|
||||
try:
|
||||
if hasattr(self.tts, "extract_ref_audio_bundle_async"):
|
||||
submit_at = time.perf_counter()
|
||||
started_at = time.perf_counter()
|
||||
result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path)
|
||||
finished_at = time.perf_counter()
|
||||
return ProfiledResult(
|
||||
result=result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(started_at),
|
||||
finished_at=float(finished_at),
|
||||
)
|
||||
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
|
||||
finally:
|
||||
self.ref_audio_gate.release()
|
||||
|
||||
def _release_split_stage_slot(self) -> None:
|
||||
self._mark_leave()
|
||||
if self._inflight_semaphore is not None:
|
||||
self._inflight_semaphore.release()
|
||||
self._inflight_gate.release()
|
||||
|
||||
async def prepare_cpu_stage_profiled_async(
|
||||
self,
|
||||
@ -312,9 +555,11 @@ class PrepareCoordinator:
|
||||
prepare_submit_at: float,
|
||||
) -> PreparedCpuStage:
|
||||
admission_start = time.perf_counter()
|
||||
if self._inflight_semaphore is not None:
|
||||
await self._inflight_semaphore.acquire()
|
||||
prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0)
|
||||
admission_stats = await self._inflight_gate.acquire()
|
||||
prepare_admission_wait_ms = max(
|
||||
float(admission_stats.get("wait_ms", 0.0)),
|
||||
(time.perf_counter() - admission_start) * 1000.0,
|
||||
)
|
||||
current_inflight, peak_inflight = self._mark_enter()
|
||||
prepare_start = time.perf_counter()
|
||||
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
|
||||
@ -382,10 +627,28 @@ class PrepareCoordinator:
|
||||
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
|
||||
"prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms,
|
||||
"prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms,
|
||||
"prompt_text_cpu_admission_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_cpu_backpressure_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_cpu_capacity_wait_ms": float(
|
||||
(cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
|
||||
),
|
||||
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
|
||||
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
|
||||
"text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms,
|
||||
"text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms,
|
||||
"text_cpu_admission_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0)
|
||||
),
|
||||
"text_cpu_backpressure_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0)
|
||||
),
|
||||
"text_cpu_capacity_wait_ms": float(
|
||||
(cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0)
|
||||
),
|
||||
"text_feature_queue_ms": target_feature_profiled.queue_ms,
|
||||
"text_feature_run_ms": target_feature_profiled.run_ms,
|
||||
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@ -51,6 +52,8 @@ class RefSemanticTask:
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
result_prompt_semantic: torch.Tensor | None = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
@ -115,6 +118,41 @@ class PrepareRefSemanticBatchWorker:
|
||||
assert task.result_prompt_semantic is not None
|
||||
return task.result_prompt_semantic, dict(task.profile)
|
||||
|
||||
async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = RefSemanticTask(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=int(raw_sr),
|
||||
done_loop=loop,
|
||||
done_future=loop.create_future(),
|
||||
)
|
||||
with self.condition:
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
return await task.done_future
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(task: RefSemanticTask) -> None:
|
||||
if task.done_future is None or task.done_future.done():
|
||||
return
|
||||
if task.error is not None:
|
||||
task.done_future.set_exception(task.error)
|
||||
return
|
||||
assert task.result_prompt_semantic is not None
|
||||
task.done_future.set_result((task.result_prompt_semantic, dict(task.profile)))
|
||||
|
||||
def _notify_task_done(self, task: RefSemanticTask) -> None:
|
||||
task.done_event.set()
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return {
|
||||
@ -247,7 +285,7 @@ class PrepareRefSemanticBatchWorker:
|
||||
for task in batch:
|
||||
if task.result_prompt_semantic is not None:
|
||||
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
|
||||
task.done_event.set()
|
||||
self._notify_task_done(task)
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
@ -257,6 +295,6 @@ class PrepareRefSemanticBatchWorker:
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
task.done_event.set()
|
||||
self._notify_task_done(task)
|
||||
finally:
|
||||
self._finalize_batch(batch)
|
||||
|
||||
215
GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py
Normal file
215
GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py
Normal file
@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Deque, Dict, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextCpuTask:
|
||||
text: str
|
||||
language: str
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
enqueued_at: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
backpressure_wait_ms: float = 0.0
|
||||
capacity_wait_ms: float = 0.0
|
||||
pending_depth_on_enqueue: int = 0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
result: Any = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareTextCpuWorker:
|
||||
def __init__(
|
||||
self,
|
||||
process_fn: Callable[[str, str], Any],
|
||||
worker_count: int,
|
||||
max_pending_tasks: int = 0,
|
||||
admission_poll_ms: int = 1,
|
||||
admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None,
|
||||
) -> None:
|
||||
self.process_fn = process_fn
|
||||
self.worker_count = max(1, int(worker_count))
|
||||
self.max_pending_tasks = max(0, int(max_pending_tasks))
|
||||
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
|
||||
self.admission_controller = admission_controller
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[TextCpuTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.total_submitted = 0
|
||||
self.total_finished = 0
|
||||
self.active_workers = 0
|
||||
self.active_workers_peak = 0
|
||||
self.admission_wait_total_ms = 0.0
|
||||
self.admission_wait_peak_ms = 0.0
|
||||
self.backpressure_wait_total_ms = 0.0
|
||||
self.backpressure_wait_peak_ms = 0.0
|
||||
self.capacity_wait_total_ms = 0.0
|
||||
self.capacity_wait_peak_ms = 0.0
|
||||
self.backpressure_blocked_total = 0
|
||||
|
||||
self.worker_threads = [
|
||||
threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True)
|
||||
for index in range(self.worker_count)
|
||||
]
|
||||
for thread in self.worker_threads:
|
||||
thread.start()
|
||||
|
||||
def _can_enqueue_locked(self) -> bool:
|
||||
if self.max_pending_tasks <= 0:
|
||||
return True
|
||||
return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks
|
||||
|
||||
def _get_admission_state(self) -> Dict[str, float | int | bool]:
|
||||
if self.admission_controller is None:
|
||||
return {"blocked": False}
|
||||
try:
|
||||
state = dict(self.admission_controller() or {})
|
||||
except Exception:
|
||||
return {"blocked": False}
|
||||
state["blocked"] = bool(state.get("blocked", False))
|
||||
return state
|
||||
|
||||
def _record_enqueue_locked(
|
||||
self,
|
||||
task: TextCpuTask,
|
||||
*,
|
||||
admission_wait_ms: float,
|
||||
backpressure_wait_ms: float,
|
||||
capacity_wait_ms: float,
|
||||
) -> None:
|
||||
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
|
||||
task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms))
|
||||
task.capacity_wait_ms = float(max(0.0, capacity_wait_ms))
|
||||
task.enqueued_at = time.perf_counter()
|
||||
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
self.admission_wait_total_ms += task.admission_wait_ms
|
||||
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
|
||||
self.backpressure_wait_total_ms += task.backpressure_wait_ms
|
||||
self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms)
|
||||
self.capacity_wait_total_ms += task.capacity_wait_ms
|
||||
self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms)
|
||||
if task.backpressure_wait_ms > 0.0:
|
||||
self.backpressure_blocked_total += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
|
||||
async def _enqueue_task_async(self, task: TextCpuTask) -> None:
|
||||
admission_started = time.perf_counter()
|
||||
backpressure_wait_ms = 0.0
|
||||
capacity_wait_ms = 0.0
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
admission_state = self._get_admission_state()
|
||||
blocked = bool(admission_state.get("blocked", False))
|
||||
with self.condition:
|
||||
if not blocked and self._can_enqueue_locked():
|
||||
self._record_enqueue_locked(
|
||||
task,
|
||||
admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0,
|
||||
backpressure_wait_ms=backpressure_wait_ms,
|
||||
capacity_wait_ms=capacity_wait_ms,
|
||||
)
|
||||
return
|
||||
await asyncio.sleep(self.admission_poll_s)
|
||||
waited_ms = (time.perf_counter() - loop_start) * 1000.0
|
||||
if blocked:
|
||||
backpressure_wait_ms += waited_ms
|
||||
else:
|
||||
capacity_wait_ms += waited_ms
|
||||
|
||||
def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
|
||||
task = TextCpuTask(text=str(text), language=str(language))
|
||||
asyncio.run(self._enqueue_task_async(task))
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
return task.result, dict(task.profile)
|
||||
|
||||
async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = TextCpuTask(
|
||||
text=str(text),
|
||||
language=str(language),
|
||||
done_loop=loop,
|
||||
done_future=loop.create_future(),
|
||||
)
|
||||
await self._enqueue_task_async(task)
|
||||
return await task.done_future
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(task: TextCpuTask) -> None:
|
||||
if task.done_future is None or task.done_future.done():
|
||||
return
|
||||
if task.error is not None:
|
||||
task.done_future.set_exception(task.error)
|
||||
return
|
||||
task.done_future.set_result((task.result, dict(task.profile)))
|
||||
|
||||
def _notify_task_done(self, task: TextCpuTask) -> None:
|
||||
task.done_event.set()
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def snapshot(self) -> Dict[str, int | float]:
|
||||
with self.condition:
|
||||
return {
|
||||
"worker_count": int(self.worker_count),
|
||||
"pending": int(len(self.pending_tasks)),
|
||||
"pending_peak": int(self.pending_peak),
|
||||
"active_workers": int(self.active_workers),
|
||||
"active_workers_peak": int(self.active_workers_peak),
|
||||
"total_submitted": int(self.total_submitted),
|
||||
"total_finished": int(self.total_finished),
|
||||
"max_pending_tasks": int(self.max_pending_tasks),
|
||||
"admission_wait_total_ms": float(self.admission_wait_total_ms),
|
||||
"admission_wait_peak_ms": float(self.admission_wait_peak_ms),
|
||||
"backpressure_wait_total_ms": float(self.backpressure_wait_total_ms),
|
||||
"backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms),
|
||||
"capacity_wait_total_ms": float(self.capacity_wait_total_ms),
|
||||
"capacity_wait_peak_ms": float(self.capacity_wait_peak_ms),
|
||||
"backpressure_blocked_total": int(self.backpressure_blocked_total),
|
||||
}
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
task = self.pending_tasks.popleft()
|
||||
self.active_workers += 1
|
||||
self.active_workers_peak = max(self.active_workers_peak, self.active_workers)
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
task.result = self.process_fn(task.text, task.language)
|
||||
task.profile = {
|
||||
"text_cpu_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms),
|
||||
"text_cpu_capacity_wait_ms": float(task.capacity_wait_ms),
|
||||
"text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0),
|
||||
"text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
|
||||
"text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
finally:
|
||||
with self.condition:
|
||||
self.active_workers = max(0, self.active_workers - 1)
|
||||
self.total_finished += 1
|
||||
self.condition.notify_all()
|
||||
self._notify_task_done(task)
|
||||
@ -421,6 +421,55 @@ def _iter_contiguous_sampling_groups(
|
||||
return groups
|
||||
|
||||
|
||||
def _uniform_sampling_group_key(active_batch: T2SActiveBatch) -> Optional[Tuple[int, float, float, float, bool]]:
|
||||
if not active_batch.states:
|
||||
return None
|
||||
if active_batch.step_indices.numel() <= 0:
|
||||
return None
|
||||
first_step_index = int(active_batch.step_indices[0].item())
|
||||
if bool((active_batch.step_indices != first_step_index).any().item()):
|
||||
return None
|
||||
first_state = active_batch.states[0]
|
||||
first_key = _sampling_group_key(
|
||||
top_k=first_state.top_k,
|
||||
top_p=first_state.top_p,
|
||||
temperature=first_state.temperature,
|
||||
repetition_penalty=first_state.repetition_penalty,
|
||||
trim_eos=first_step_index < 11,
|
||||
)
|
||||
for state in active_batch.states[1:]:
|
||||
if (
|
||||
state.top_k != first_state.top_k
|
||||
or state.top_p != first_state.top_p
|
||||
or state.temperature != first_state.temperature
|
||||
or state.repetition_penalty != first_state.repetition_penalty
|
||||
):
|
||||
return None
|
||||
return first_key
|
||||
|
||||
|
||||
def _batched_sample_uniform(
|
||||
logits: torch.Tensor,
|
||||
histories: Sequence[torch.LongTensor],
|
||||
sampling_key: Tuple[int, float, float, float, bool],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
top_k, top_p, temperature, repetition_penalty, trim_eos = sampling_key
|
||||
sample_logits = logits[:, :-1] if trim_eos else logits
|
||||
padded_histories, history_mask = _pad_token_sequences(histories)
|
||||
probs = logits_to_probs(
|
||||
logits=sample_logits,
|
||||
previous_tokens=padded_histories,
|
||||
previous_token_mask=history_mask,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
)
|
||||
sampled = multinomial_sample_one_no_sync(probs)
|
||||
argmax_tokens = torch.argmax(sample_logits, dim=-1)
|
||||
return sampled, argmax_tokens
|
||||
|
||||
|
||||
def _batched_sample_by_group(
|
||||
logits: torch.Tensor,
|
||||
histories: Sequence[torch.LongTensor],
|
||||
@ -594,27 +643,59 @@ def _sample_per_request(
|
||||
keep_indices: List[int] = []
|
||||
updated_sequences: List[torch.LongTensor] = []
|
||||
|
||||
sampling_keys = [
|
||||
_sampling_group_key(
|
||||
top_k=state.top_k,
|
||||
top_p=state.top_p,
|
||||
temperature=state.temperature,
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
trim_eos=int(active_batch.step_indices[batch_index].item()) < 11,
|
||||
uniform_sampling_key = _uniform_sampling_group_key(active_batch)
|
||||
sampled_items: List[torch.Tensor]
|
||||
argmax_tokens: List[int]
|
||||
sampled_token_tensor: Optional[torch.Tensor] = None
|
||||
argmax_token_tensor: Optional[torch.Tensor] = None
|
||||
if uniform_sampling_key is not None:
|
||||
sampled_tensor, argmax_tensor = _batched_sample_uniform(
|
||||
logits=logits,
|
||||
histories=active_batch.y_sequences,
|
||||
sampling_key=uniform_sampling_key,
|
||||
)
|
||||
sampled_token_tensor = sampled_tensor.view(-1)
|
||||
argmax_token_tensor = argmax_tensor.view(-1)
|
||||
if (
|
||||
all(state.early_stop_num == -1 for state in active_batch.states)
|
||||
and int(active_batch.step_indices[0].item()) + 1 < max_steps
|
||||
and not bool(sampled_token_tensor.eq(model.EOS).any().item())
|
||||
and not bool(argmax_token_tensor.eq(model.EOS).any().item())
|
||||
):
|
||||
return (
|
||||
[],
|
||||
list(range(len(active_batch.states))),
|
||||
[torch.cat([history, sampled_token_tensor[index : index + 1]], dim=0) for index, history in enumerate(active_batch.y_sequences)],
|
||||
)
|
||||
sampled_items = [sampled_tensor[index : index + 1] for index in range(sampled_tensor.shape[0])]
|
||||
argmax_tokens = [int(item) for item in argmax_tensor.tolist()]
|
||||
else:
|
||||
sampling_keys = [
|
||||
_sampling_group_key(
|
||||
top_k=state.top_k,
|
||||
top_p=state.top_p,
|
||||
temperature=state.temperature,
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
trim_eos=int(active_batch.step_indices[batch_index].item()) < 11,
|
||||
)
|
||||
for batch_index, state in enumerate(active_batch.states)
|
||||
]
|
||||
sampled_items, argmax_tokens = _batched_sample_by_group(
|
||||
logits=logits,
|
||||
histories=active_batch.y_sequences,
|
||||
sampling_keys=sampling_keys,
|
||||
)
|
||||
for batch_index, state in enumerate(active_batch.states)
|
||||
]
|
||||
sampled_items, argmax_tokens = _batched_sample_by_group(
|
||||
logits=logits,
|
||||
histories=active_batch.y_sequences,
|
||||
sampling_keys=sampling_keys,
|
||||
)
|
||||
for batch_index, state in enumerate(active_batch.states):
|
||||
step_index = int(active_batch.step_indices[batch_index].item())
|
||||
current_history = active_batch.y_sequences[batch_index]
|
||||
sampled = sampled_items[batch_index]
|
||||
sampled_token = int(sampled[0, 0].item())
|
||||
argmax_token = argmax_tokens[batch_index]
|
||||
if sampled_token_tensor is not None and argmax_token_tensor is not None:
|
||||
sampled = sampled_token_tensor[batch_index : batch_index + 1]
|
||||
sampled_token = int(sampled_token_tensor[batch_index].item())
|
||||
argmax_token = int(argmax_token_tensor[batch_index].item())
|
||||
else:
|
||||
sampled = sampled_items[batch_index]
|
||||
sampled_token = int(sampled[0, 0].item())
|
||||
argmax_token = argmax_tokens[batch_index]
|
||||
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
|
||||
|
||||
finish_reason: Optional[str] = None
|
||||
@ -690,6 +771,13 @@ def decode_one_step(
|
||||
finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps)
|
||||
if len(keep_indices) == 0:
|
||||
return None, finished_items
|
||||
if len(keep_indices) == len(active_batch.request_ids):
|
||||
active_batch.y_sequences = updated_sequences
|
||||
active_batch.step_indices = active_batch.step_indices + 1
|
||||
if not was_prefill and active_batch.kv_lens is not None:
|
||||
active_batch.kv_lens = active_batch.kv_lens + 1
|
||||
active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences)
|
||||
return active_batch, finished_items
|
||||
|
||||
device = logits.device
|
||||
keep_tensor = torch.LongTensor(keep_indices).to(device)
|
||||
|
||||
@ -314,6 +314,7 @@ def build_scheduler_submit_headers(
|
||||
"X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)),
|
||||
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
|
||||
"X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))),
|
||||
"X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)),
|
||||
"X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)),
|
||||
|
||||
@ -44,6 +44,16 @@ class EngineTaskQueueOwner:
|
||||
return None
|
||||
return self.queue.popleft()
|
||||
|
||||
def pop_left_many(self, max_items: int) -> List[Any]:
|
||||
limit = max(1, int(max_items))
|
||||
with self.condition:
|
||||
if not self.queue:
|
||||
return []
|
||||
selected: List[Any] = []
|
||||
while self.queue and len(selected) < limit:
|
||||
selected.append(self.queue.popleft())
|
||||
return selected
|
||||
|
||||
def mark_completed(self, count: int = 1, *, notify: bool = False) -> None:
|
||||
if count <= 0:
|
||||
return
|
||||
@ -315,6 +325,7 @@ class EngineGpuPrepareTask:
|
||||
engine_request_id: str | None
|
||||
enqueue_time: float
|
||||
queue_wait_ms: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
@ -9,6 +10,19 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare
|
||||
|
||||
|
||||
class EnginePrepareStageMixin:
|
||||
async def _wait_prepare_queue_admission(self) -> float:
|
||||
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
|
||||
if soft_max <= 0:
|
||||
return 0.0
|
||||
poll_s = max(
|
||||
0.0005,
|
||||
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
|
||||
)
|
||||
wait_start = time.perf_counter()
|
||||
while self.prepare_queue_owner.waiting_count() >= soft_max:
|
||||
await asyncio.sleep(poll_s)
|
||||
return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
|
||||
|
||||
async def prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
@ -16,12 +30,14 @@ class EnginePrepareStageMixin:
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission()
|
||||
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
|
||||
if engine_request_id not in [None, ""]:
|
||||
self.update_request_state(
|
||||
str(engine_request_id),
|
||||
EngineStatus.GPU_PREPARING,
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms),
|
||||
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
|
||||
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
|
||||
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
|
||||
@ -37,31 +53,44 @@ class EnginePrepareStageMixin:
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or spec.request_id,
|
||||
enqueue_time=time.perf_counter(),
|
||||
admission_wait_ms=float(prepare_queue_admission_wait_ms),
|
||||
)
|
||||
self.prepare_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
return await done_future
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
task = self.prepare_queue_owner.pop_left()
|
||||
if task is None:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
if not tasks:
|
||||
return False
|
||||
queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0)
|
||||
try:
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run(
|
||||
self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage)
|
||||
)
|
||||
now = time.perf_counter()
|
||||
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
|
||||
batch_results = asyncio.run(
|
||||
self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks])
|
||||
)
|
||||
completed_count = 0
|
||||
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
task.error = str(result)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
|
||||
self._notify_prepare_error(task, result)
|
||||
completed_count += 1
|
||||
continue
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = result
|
||||
state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms)
|
||||
state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms)
|
||||
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)},
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms),
|
||||
"engine_gpu_prepare_batch_size": float(len(tasks)),
|
||||
},
|
||||
)
|
||||
self.prepare_queue_owner.mark_completed(1)
|
||||
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
|
||||
return True
|
||||
except Exception as exc:
|
||||
task.error = str(exc)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(exc))
|
||||
self._notify_prepare_error(task, exc)
|
||||
return True
|
||||
completed_count += 1
|
||||
self.prepare_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
@ -32,6 +33,11 @@ class WorkerPrepareExecutor:
|
||||
def get_max_inflight(self) -> int:
|
||||
return int(self.coordinator.snapshot().get("max_inflight", 0))
|
||||
|
||||
def get_batch_policy(self) -> Dict[str, int]:
|
||||
return {
|
||||
"prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))),
|
||||
}
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return int(self.coordinator.snapshot().get("inflight", 0)) <= 0
|
||||
|
||||
@ -69,3 +75,17 @@ class WorkerPrepareExecutor:
|
||||
return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_stages_profiled_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[tuple[T2SRequestState, float, float] | Exception]:
|
||||
try:
|
||||
return list(
|
||||
await asyncio.gather(
|
||||
*[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages],
|
||||
return_exceptions=True,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
@ -32,6 +32,9 @@ class WorkerSubmitLifecycleMixin:
|
||||
def get_finalize_batch_policy(self) -> Dict[str, Any]:
|
||||
return dict(self.finalize_executor.get_batch_policy())
|
||||
|
||||
def get_prepare_batch_policy(self) -> Dict[str, int]:
|
||||
return dict(self.prepare_executor.get_batch_policy())
|
||||
|
||||
def get_decode_runtime_counters(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return self.decode_runtime_tracker.get_counters()
|
||||
@ -258,3 +261,9 @@ class WorkerSubmitLifecycleMixin:
|
||||
cpu_stage: PreparedCpuStage,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)
|
||||
|
||||
async def prepare_gpu_stages_profiled_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[tuple[T2SRequestState, float, float] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
|
||||
|
||||
1
third_party/g2pw-cu
vendored
Submodule
1
third_party/g2pw-cu
vendored
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98
|
||||
Loading…
x
Reference in New Issue
Block a user