mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-20 09:28:11 +08:00
Add unified engine API modules for direct and scheduler-based TTS processing
Introduce new modules including unified_engine_api_direct, unified_engine_api_profile, unified_engine_api_request, and unified_engine_api_scheduler. These additions enhance the TTS system by providing structured interfaces for direct TTS execution and scheduler-based processing. The new components support improved request handling, profiling, and state management, significantly enhancing the architecture and maintainability of the TTS framework.
This commit is contained in:
parent
b046a093d3
commit
3fd4f48651
File diff suppressed because it is too large
Load Diff
441
GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py
Normal file
441
GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py
Normal file
@ -0,0 +1,441 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, wave_header_chunk
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineStatus, NormalizedEngineRequest, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineApiDirectFlow:
|
||||
def __init__(self, api: Any) -> None:
|
||||
self.api = api
|
||||
|
||||
def _iter_legacy_direct_tts_bytes(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
payload = normalized.to_payload()
|
||||
media_type = normalized.media_type
|
||||
request_id = normalized.request_id
|
||||
request_start = time.perf_counter()
|
||||
chunk_count = 0
|
||||
stream_total_bytes = 0
|
||||
first_chunk_ms: float | None = None
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
|
||||
)
|
||||
try:
|
||||
with self.api.direct_tts_lock:
|
||||
tts_generator = self.api.tts.run(payload)
|
||||
first_chunk = True
|
||||
current_media_type = media_type
|
||||
for sr, chunk in tts_generator:
|
||||
if first_chunk:
|
||||
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.STREAMING,
|
||||
{
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
"sample_rate": int(sr),
|
||||
},
|
||||
)
|
||||
if first_chunk and media_type == "wav":
|
||||
header = wave_header_chunk(sample_rate=sr)
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(header)
|
||||
yield header
|
||||
current_media_type = "raw"
|
||||
first_chunk = False
|
||||
elif first_chunk:
|
||||
first_chunk = False
|
||||
packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue()
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(packed_chunk)
|
||||
yield packed_chunk
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(
|
||||
self.api._build_legacy_direct_profile(
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
request_start=request_start,
|
||||
finished_at=time.perf_counter(),
|
||||
audio_bytes=stream_total_bytes,
|
||||
chunk_count=chunk_count,
|
||||
stream_total_bytes=stream_total_bytes,
|
||||
first_chunk_ms=first_chunk_ms,
|
||||
),
|
||||
streaming_completed=True,
|
||||
),
|
||||
)
|
||||
|
||||
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
|
||||
if isinstance(req, NormalizedEngineRequest):
|
||||
normalized = req
|
||||
else:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
req,
|
||||
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
|
||||
normalize_streaming=True,
|
||||
)
|
||||
backend, _ = self.api._select_direct_backend(normalized)
|
||||
return backend == "scheduler_v1_direct"
|
||||
|
||||
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
|
||||
payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized
|
||||
return self.api.tts.text_preprocessor.pre_seg_text(
|
||||
str(payload["text"]),
|
||||
str(payload["text_lang"]),
|
||||
str(payload.get("text_split_method", "cut5")),
|
||||
)
|
||||
|
||||
def _build_segment_request(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str,
|
||||
text: str,
|
||||
) -> NormalizedEngineRequest:
|
||||
payload = normalized.to_payload()
|
||||
payload["request_id"] = request_id
|
||||
payload["text"] = text
|
||||
payload["streaming_mode"] = False
|
||||
payload["return_fragment"] = False
|
||||
payload["fixed_length_chunk"] = False
|
||||
payload["response_streaming"] = False
|
||||
return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ")
|
||||
|
||||
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
|
||||
request_start = time.perf_counter()
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
segment_texts = self._segment_direct_text(normalized)
|
||||
if not segment_texts:
|
||||
raise ValueError("text preprocessing returned no valid segments")
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.CPU_PREPARING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
|
||||
)
|
||||
segment_specs = []
|
||||
for segment_index, segment_text in enumerate(segment_texts):
|
||||
segment_request = self._build_segment_request(
|
||||
normalized,
|
||||
request_id=f"{request_id}_seg_{segment_index:03d}",
|
||||
text=segment_text,
|
||||
)
|
||||
segment_specs.append(self.api._build_scheduler_submit_spec(segment_request))
|
||||
|
||||
prepared_items = await asyncio.gather(
|
||||
*[
|
||||
self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=time.perf_counter(),
|
||||
engine_request_id=None,
|
||||
)
|
||||
for spec in segment_specs
|
||||
]
|
||||
)
|
||||
prepare_profiles: List[Dict[str, Any]] = []
|
||||
loop = asyncio.get_running_loop()
|
||||
done_futures: List[asyncio.Future] = []
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)},
|
||||
)
|
||||
for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items):
|
||||
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
|
||||
prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
prepare_profiles.append(
|
||||
{
|
||||
"request_id": spec.request_id,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
}
|
||||
)
|
||||
done_future = loop.create_future()
|
||||
done_futures.append(done_future)
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=media_type,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=None,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"},
|
||||
)
|
||||
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
|
||||
jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec))
|
||||
for profile_item, job in zip(prepare_profiles, jobs):
|
||||
profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms)
|
||||
profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms)
|
||||
self.api._merge_request_state_profile(
|
||||
request_id,
|
||||
{
|
||||
"engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs),
|
||||
"engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs),
|
||||
"prepare_aggregate": self.api._aggregate_numeric_dicts([item["prepare_profile"] for item in prepare_profiles]),
|
||||
},
|
||||
)
|
||||
|
||||
sample_rate: int | None = None
|
||||
audio_parts: List[np.ndarray] = []
|
||||
worker_profiles: List[Dict[str, Any]] = []
|
||||
fragment_interval = float(normalized.fragment_interval)
|
||||
silence_chunk: Optional[np.ndarray] = None
|
||||
for job in jobs:
|
||||
if job.error is not None:
|
||||
raise RuntimeError(job.error)
|
||||
if job.audio_data is None or job.sample_rate is None or job.result is None:
|
||||
raise RuntimeError(f"{job.request_id} finished without audio result")
|
||||
if sample_rate is None:
|
||||
sample_rate = int(job.sample_rate)
|
||||
silence_samples = int(fragment_interval * float(sample_rate))
|
||||
if silence_samples > 0:
|
||||
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
|
||||
elif int(job.sample_rate) != sample_rate:
|
||||
raise RuntimeError("segment sample rate mismatch")
|
||||
audio_parts.append(job.audio_data)
|
||||
if silence_chunk is not None:
|
||||
audio_parts.append(silence_chunk.copy())
|
||||
worker_profiles.append(dict(job.result))
|
||||
if sample_rate is None or not audio_parts:
|
||||
raise RuntimeError("direct scheduler backend produced no audio")
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"},
|
||||
)
|
||||
merged_audio = np.concatenate(audio_parts, axis=0)
|
||||
pack_start = time.perf_counter()
|
||||
audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue()
|
||||
pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
|
||||
direct_profile = self.api._build_direct_scheduler_profile(
|
||||
backend="scheduler_v1_direct",
|
||||
request_start=request_start,
|
||||
response_ready_at=time.perf_counter(),
|
||||
audio_bytes=len(audio_bytes),
|
||||
sample_rate=int(sample_rate),
|
||||
segment_texts=segment_texts,
|
||||
prepare_profiles=prepare_profiles,
|
||||
worker_profiles=worker_profiles,
|
||||
pack_ms=pack_ms,
|
||||
response_overhead_ms=0.0,
|
||||
)
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(direct_profile, streaming_completed=False),
|
||||
)
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=False,
|
||||
audio_bytes=audio_bytes,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _run_legacy_direct_tts_blocking(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
normalized_payload = normalized.to_payload()
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
request_start = time.perf_counter()
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
|
||||
)
|
||||
with self.api.direct_tts_lock:
|
||||
tts_generator = self.api.tts.run(normalized_payload)
|
||||
try:
|
||||
sr, audio_data = next(tts_generator)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
|
||||
)
|
||||
pack_start = time.perf_counter()
|
||||
packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(
|
||||
self.api._build_legacy_direct_profile(
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
request_start=request_start,
|
||||
finished_at=time.perf_counter(),
|
||||
sample_rate=int(sr),
|
||||
audio_bytes=len(packed_audio),
|
||||
pack_ms=pack_ms,
|
||||
),
|
||||
streaming_completed=False,
|
||||
),
|
||||
)
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=False,
|
||||
audio_bytes=packed_audio,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
async def _run_direct_tts_via_legacy_backend(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=normalized.media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
),
|
||||
request_id=normalized.request_id,
|
||||
)
|
||||
return await asyncio.to_thread(
|
||||
self._run_legacy_direct_tts_blocking,
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
req,
|
||||
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
|
||||
normalize_streaming=True,
|
||||
error_prefix="",
|
||||
)
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
backend, fallback_reason = self.api._select_direct_backend(normalized)
|
||||
self.api._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode="tts",
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=bool(normalized.response_streaming),
|
||||
deadline_ts=(time.perf_counter() + float(normalized.timeout_sec) if normalized.timeout_sec is not None else None),
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.VALIDATED,
|
||||
{
|
||||
"request_source": "direct_tts",
|
||||
"selected_backend": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
},
|
||||
)
|
||||
if backend == "scheduler_v1_direct":
|
||||
try:
|
||||
return await self._run_direct_tts_via_scheduler(normalized)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
return await self._run_direct_tts_via_legacy_backend(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
req,
|
||||
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
|
||||
normalize_streaming=True,
|
||||
error_prefix="",
|
||||
)
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
backend, fallback_reason = self.api._select_direct_backend(normalized)
|
||||
if not self.api._has_active_request(request_id):
|
||||
self.api._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode="tts",
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=bool(normalized.response_streaming),
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.VALIDATED,
|
||||
{
|
||||
"request_source": "direct_tts",
|
||||
"selected_backend": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
},
|
||||
)
|
||||
if backend != "scheduler_v1_direct":
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
return self._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend="legacy_direct_sync_compat",
|
||||
fallback_reason="sync_direct_compat",
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
return self._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend="legacy_direct_sync_compat",
|
||||
fallback_reason="sync_direct_compat",
|
||||
)
|
||||
387
GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py
Normal file
387
GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py
Normal file
@ -0,0 +1,387 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
|
||||
|
||||
|
||||
def build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
text = payload.get("text")
|
||||
prompt_text = payload.get("prompt_text")
|
||||
return {
|
||||
"text_len": 0 if text is None else len(str(text)),
|
||||
"prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)),
|
||||
"text_lang": payload.get("text_lang"),
|
||||
"prompt_lang": payload.get("prompt_lang"),
|
||||
"ref_audio_path": payload.get("ref_audio_path"),
|
||||
}
|
||||
|
||||
|
||||
def sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
|
||||
total = 0.0
|
||||
for item in items:
|
||||
value = item.get(key, 0.0)
|
||||
if isinstance(value, (int, float)):
|
||||
total += float(value)
|
||||
return total
|
||||
|
||||
|
||||
def aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
|
||||
totals: Dict[str, float] = {}
|
||||
for item in items:
|
||||
for key, value in item.items():
|
||||
if isinstance(value, (int, float)):
|
||||
totals[key] = totals.get(key, 0.0) + float(value)
|
||||
return totals
|
||||
|
||||
|
||||
def build_direct_segment_trace(
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
results: List[Dict[str, Any]] = []
|
||||
for index, segment_text in enumerate(segment_texts):
|
||||
prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {}
|
||||
worker_item = worker_profiles[index] if index < len(worker_profiles) else {}
|
||||
prepare_profile = dict(prepare_item.get("prepare_profile", {}))
|
||||
results.append(
|
||||
{
|
||||
"segment_index": index,
|
||||
"request_id": prepare_item.get("request_id") or worker_item.get("request_id"),
|
||||
"text_len": len(str(segment_text)),
|
||||
"prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)),
|
||||
"prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)),
|
||||
"prepare_engine_gpu_queue_wait_ms": float(
|
||||
dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0)
|
||||
),
|
||||
"engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)),
|
||||
"engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)),
|
||||
"decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)),
|
||||
"queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)),
|
||||
"prefill_ms": float(worker_item.get("prefill_ms", 0.0)),
|
||||
"merge_ms": float(worker_item.get("merge_ms", 0.0)),
|
||||
"decode_ms": float(worker_item.get("decode_ms", 0.0)),
|
||||
"finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)),
|
||||
"synth_ms": float(worker_item.get("synth_ms", 0.0)),
|
||||
"worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)),
|
||||
"decode_steps": int(worker_item.get("decode_steps", 0)),
|
||||
"semantic_len": int(worker_item.get("semantic_len", 0)),
|
||||
"finish_reason": worker_item.get("finish_reason"),
|
||||
"norm_text": prepare_profile.get("norm_text"),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def build_direct_scheduler_profile(
|
||||
*,
|
||||
backend: str,
|
||||
request_start: float,
|
||||
response_ready_at: float,
|
||||
audio_bytes: int,
|
||||
sample_rate: int,
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
pack_ms: float,
|
||||
response_overhead_ms: float,
|
||||
) -> Dict[str, Any]:
|
||||
segment_trace = build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
|
||||
prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles]
|
||||
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
|
||||
prepare_wall_ms = sum_profile_field(prepare_profiles, "prepare_wall_ms")
|
||||
prepare_profile_total_ms = sum_profile_field(prepare_profiles, "prepare_profile_total_ms")
|
||||
engine_policy_wait_ms = sum_profile_field(prepare_profiles, "engine_policy_wait_ms")
|
||||
engine_dispatch_wait_ms = sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms")
|
||||
decode_admission_wait_ms = sum_profile_field(worker_profiles, "decode_admission_wait_ms")
|
||||
queue_wait_ms = sum_profile_field(worker_profiles, "queue_wait_ms")
|
||||
prefill_ms = sum_profile_field(worker_profiles, "prefill_ms")
|
||||
merge_ms = sum_profile_field(worker_profiles, "merge_ms")
|
||||
decode_ms = sum_profile_field(worker_profiles, "decode_ms")
|
||||
finalize_wait_ms = sum_profile_field(worker_profiles, "finalize_wait_ms")
|
||||
synth_ms = sum_profile_field(worker_profiles, "synth_ms")
|
||||
worker_total_ms = sum_profile_field(worker_profiles, "worker_total_ms")
|
||||
decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles)
|
||||
semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles)
|
||||
request_other_ms = max(
|
||||
0.0,
|
||||
request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms,
|
||||
)
|
||||
return {
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"segment_count": len(segment_texts),
|
||||
"sample_rate": int(sample_rate),
|
||||
"audio_bytes": int(audio_bytes),
|
||||
"request_total_ms": request_total_ms,
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"engine_policy_wait_ms": engine_policy_wait_ms,
|
||||
"engine_dispatch_wait_ms": engine_dispatch_wait_ms,
|
||||
"decode_admission_wait_ms": decode_admission_wait_ms,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
"prefill_ms": prefill_ms,
|
||||
"merge_ms": merge_ms,
|
||||
"decode_ms": decode_ms,
|
||||
"finalize_wait_ms": finalize_wait_ms,
|
||||
"synth_ms": synth_ms,
|
||||
"pack_ms": pack_ms,
|
||||
"response_overhead_ms": response_overhead_ms,
|
||||
"worker_total_ms": worker_total_ms,
|
||||
"request_other_ms": request_other_ms,
|
||||
"decode_steps": decode_steps,
|
||||
"semantic_len": semantic_len,
|
||||
"prepare_segments": list(prepare_profiles),
|
||||
"worker_segments": list(worker_profiles),
|
||||
"segment_trace": segment_trace,
|
||||
"prepare_aggregate": aggregate_numeric_dicts(prepare_profile_dicts),
|
||||
}
|
||||
|
||||
|
||||
def build_legacy_direct_profile(
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
request_start: float,
|
||||
finished_at: float,
|
||||
sample_rate: int | None = None,
|
||||
audio_bytes: int = 0,
|
||||
pack_ms: float = 0.0,
|
||||
chunk_count: int = 0,
|
||||
stream_total_bytes: int = 0,
|
||||
first_chunk_ms: float | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
request_total_ms = max(0.0, (finished_at - request_start) * 1000.0)
|
||||
legacy_infer_ms = max(0.0, request_total_ms - pack_ms)
|
||||
return {
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
"request_total_ms": request_total_ms,
|
||||
"prepare_ms": 0.0,
|
||||
"queue_wait_ms": 0.0,
|
||||
"prefill_ms": 0.0,
|
||||
"merge_ms": 0.0,
|
||||
"decode_ms": 0.0,
|
||||
"finalize_wait_ms": 0.0,
|
||||
"synth_ms": 0.0,
|
||||
"pack_ms": pack_ms,
|
||||
"worker_total_ms": legacy_infer_ms,
|
||||
"request_other_ms": 0.0,
|
||||
"legacy_infer_ms": legacy_infer_ms,
|
||||
"sample_rate": int(sample_rate) if sample_rate is not None else None,
|
||||
"audio_bytes": int(audio_bytes),
|
||||
"chunk_count": int(chunk_count),
|
||||
"stream_total_bytes": int(stream_total_bytes),
|
||||
"first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms),
|
||||
}
|
||||
|
||||
|
||||
def build_scheduler_submit_profile(
|
||||
*,
|
||||
backend: str,
|
||||
request_start: float,
|
||||
response_ready_at: float,
|
||||
audio_bytes: int,
|
||||
sample_rate: int,
|
||||
prepare_spec_build_ms: float,
|
||||
prepare_wall_ms: float,
|
||||
prepare_executor_queue_ms: float,
|
||||
prepare_executor_run_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
prepare_profile_wall_ms: float,
|
||||
prepare_other_ms: float,
|
||||
engine_policy_wait_ms: float,
|
||||
api_after_prepare_ms: float,
|
||||
api_wait_result_ms: float,
|
||||
pack_ms: float,
|
||||
response_overhead_ms: float,
|
||||
worker_profile: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0))
|
||||
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
|
||||
request_other_ms = max(
|
||||
0.0,
|
||||
request_total_ms
|
||||
- prepare_wall_ms
|
||||
- engine_policy_wait_ms
|
||||
- api_after_prepare_ms
|
||||
- worker_total_ms
|
||||
- api_wait_result_ms
|
||||
- pack_ms,
|
||||
)
|
||||
result = {
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"audio_bytes": int(audio_bytes),
|
||||
"sample_rate": int(sample_rate),
|
||||
"prepare_spec_build_ms": prepare_spec_build_ms,
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_executor_queue_ms": prepare_executor_queue_ms,
|
||||
"prepare_executor_run_ms": prepare_executor_run_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile_wall_ms": prepare_profile_wall_ms,
|
||||
"prepare_other_ms": prepare_other_ms,
|
||||
"engine_policy_wait_ms": float(engine_policy_wait_ms),
|
||||
"api_after_prepare_ms": api_after_prepare_ms,
|
||||
"api_wait_result_ms": api_wait_result_ms,
|
||||
"pack_ms": pack_ms,
|
||||
"response_overhead_ms": response_overhead_ms,
|
||||
"request_total_ms": request_total_ms,
|
||||
"request_other_ms": request_other_ms,
|
||||
}
|
||||
result.update({key: value for key, value in worker_profile.items()})
|
||||
return result
|
||||
|
||||
|
||||
def format_ms_header(value: Any) -> str:
|
||||
return f"{float(value):.3f}"
|
||||
|
||||
|
||||
def build_scheduler_submit_headers(
|
||||
*,
|
||||
request_id: str,
|
||||
media_type: str,
|
||||
sample_rate: int,
|
||||
profile: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
prepare_profile = dict(profile.get("prepare_profile", {}))
|
||||
headers = {
|
||||
"X-Request-Id": request_id,
|
||||
"X-Semantic-Len": str(int(profile.get("semantic_len", 0))),
|
||||
"X-Finish-Reason": str(profile.get("finish_reason", "unknown")),
|
||||
"X-Queue-Wait-Ms": format_ms_header(profile.get("queue_wait_ms", 0.0)),
|
||||
"X-Decode-Admission-Wait-Ms": format_ms_header(profile.get("decode_admission_wait_ms", 0.0)),
|
||||
"X-Engine-Policy-Wait-Ms": format_ms_header(profile.get("engine_policy_wait_ms", 0.0)),
|
||||
"X-Engine-Dispatch-Wait-Ms": format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)),
|
||||
"X-Prepare-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)),
|
||||
"X-Prepare-Wall-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)),
|
||||
"X-Prepare-Spec-Build-Ms": format_ms_header(profile.get("prepare_spec_build_ms", 0.0)),
|
||||
"X-Prepare-Executor-Queue-Ms": format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)),
|
||||
"X-Prepare-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)),
|
||||
"X-Prepare-Executor-Run-Ms": format_ms_header(profile.get("prepare_executor_run_ms", 0.0)),
|
||||
"X-Prepare-Profile-Total-Ms": format_ms_header(profile.get("prepare_profile_total_ms", 0.0)),
|
||||
"X-Prepare-Profile-Wall-Ms": format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)),
|
||||
"X-Prepare-Other-Ms": format_ms_header(profile.get("prepare_other_ms", 0.0)),
|
||||
"X-Api-After-Prepare-Ms": format_ms_header(profile.get("api_after_prepare_ms", 0.0)),
|
||||
"X-Prefill-Ms": format_ms_header(profile.get("prefill_ms", 0.0)),
|
||||
"X-Merge-Ms": format_ms_header(profile.get("merge_ms", 0.0)),
|
||||
"X-Decode-Ms": format_ms_header(profile.get("decode_ms", 0.0)),
|
||||
"X-Finalize-Wait-Ms": format_ms_header(profile.get("finalize_wait_ms", 0.0)),
|
||||
"X-Synth-Ms": format_ms_header(profile.get("synth_ms", 0.0)),
|
||||
"X-Worker-Residual-Ms": format_ms_header(profile.get("worker_residual_ms", 0.0)),
|
||||
"X-Worker-Other-Ms": format_ms_header(profile.get("worker_other_ms", 0.0)),
|
||||
"X-Pack-Ms": format_ms_header(profile.get("pack_ms", 0.0)),
|
||||
"X-Worker-Total-Ms": format_ms_header(profile.get("worker_total_ms", 0.0)),
|
||||
"X-Api-Wait-Result-Ms": format_ms_header(profile.get("api_wait_result_ms", 0.0)),
|
||||
"X-Decode-Steps": str(int(profile.get("decode_steps", 0))),
|
||||
"X-Sample-Rate": str(int(sample_rate)),
|
||||
"X-Response-Overhead-Ms": format_ms_header(profile.get("response_overhead_ms", 0.0)),
|
||||
"X-Request-Other-Ms": format_ms_header(profile.get("request_other_ms", 0.0)),
|
||||
"X-Request-Total-Ms": format_ms_header(profile.get("request_total_ms", 0.0)),
|
||||
}
|
||||
headers.update(
|
||||
{
|
||||
"X-Prepare-Prompt-Text-Ms": format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-Ms": format_ms_header(prepare_profile.get("text_features_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Forward-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Forward-Ms": format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))),
|
||||
"X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))),
|
||||
"X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))),
|
||||
"X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))),
|
||||
"X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))),
|
||||
"X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))),
|
||||
"X-Prepare-Prompt-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)),
|
||||
"X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)),
|
||||
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
|
||||
"X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)),
|
||||
"X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Wait-Ms": format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-CPU-Ms": format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Forward-Ms": format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"X-Prepare-Ref-Spec-Ms": format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)),
|
||||
"X-Prepare-Ref-Spec-Wait-Ms": format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)),
|
||||
"X-Prepare-Ref-Bundle-Ms": format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)),
|
||||
"X-Prepare-Tensorize-Ms": format_ms_header(prepare_profile.get("tensorize_ms", 0.0)),
|
||||
"X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))),
|
||||
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
|
||||
}
|
||||
)
|
||||
return headers
|
||||
|
||||
|
||||
def build_scheduler_debug_request_profile(
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
item: T2SFinishedItem,
|
||||
batch_request_count: int,
|
||||
prepare_batch_wall_ms: float,
|
||||
decode_batch_wall_ms: float,
|
||||
batch_request_total_ms: float,
|
||||
) -> Dict[str, Any]:
|
||||
prepare_profile = dict(state.prepare_profile)
|
||||
prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0))
|
||||
return {
|
||||
"backend": "scheduler_debug",
|
||||
"backend_mode": "scheduler_debug",
|
||||
"batch_request_count": int(batch_request_count),
|
||||
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
|
||||
"batch_decode_wall_ms": float(decode_batch_wall_ms),
|
||||
"batch_request_total_ms": float(batch_request_total_ms),
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)),
|
||||
"prepare_profile": prepare_profile,
|
||||
"decode_steps": int(item.finish_idx),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_reason": item.finish_reason,
|
||||
"norm_text": state.norm_text,
|
||||
"norm_prompt_text": state.norm_prompt_text,
|
||||
}
|
||||
|
||||
|
||||
def build_scheduler_debug_batch_profile(
|
||||
*,
|
||||
request_count: int,
|
||||
max_steps: int,
|
||||
prepare_batch_wall_ms: float,
|
||||
decode_batch_wall_ms: float,
|
||||
request_total_ms: float,
|
||||
finished_items: Sequence[T2SFinishedItem],
|
||||
) -> Dict[str, Any]:
|
||||
finish_reason_counts: Dict[str, int] = {}
|
||||
total_semantic_len = 0
|
||||
for item in finished_items:
|
||||
finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1
|
||||
total_semantic_len += int(item.semantic_tokens.shape[0])
|
||||
return {
|
||||
"request_count": int(request_count),
|
||||
"max_steps": int(max_steps),
|
||||
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
|
||||
"decode_batch_wall_ms": float(decode_batch_wall_ms),
|
||||
"request_total_ms": float(request_total_ms),
|
||||
"total_semantic_len": int(total_semantic_len),
|
||||
"finish_reason_counts": finish_reason_counts,
|
||||
}
|
||||
199
GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py
Normal file
199
GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py
Normal file
@ -0,0 +1,199 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry
|
||||
|
||||
|
||||
def normalize_lang(value: str | None) -> str | None:
|
||||
if value in [None, ""]:
|
||||
return value
|
||||
return str(value).lower()
|
||||
|
||||
|
||||
def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict:
|
||||
normalized = dict(req)
|
||||
default_ref = reference_registry.get_default()
|
||||
if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]:
|
||||
normalized["ref_audio_path"] = default_ref.ref_audio_path
|
||||
if "text_lang" in normalized:
|
||||
normalized["text_lang"] = normalize_lang(normalized.get("text_lang"))
|
||||
if "prompt_lang" in normalized:
|
||||
normalized["prompt_lang"] = normalize_lang(normalized.get("prompt_lang"))
|
||||
return normalized
|
||||
|
||||
|
||||
def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]:
|
||||
text = req.get("text", "")
|
||||
text_lang = req.get("text_lang", "")
|
||||
ref_audio_path = req.get("ref_audio_path", "")
|
||||
media_type = req.get("media_type", "wav")
|
||||
prompt_lang = req.get("prompt_lang", "")
|
||||
text_split_method = req.get("text_split_method", "cut5")
|
||||
|
||||
if ref_audio_path in [None, ""]:
|
||||
return "ref_audio_path is required"
|
||||
if text in [None, ""]:
|
||||
return "text is required"
|
||||
if text_lang in [None, ""]:
|
||||
return "text_lang is required"
|
||||
if text_lang.lower() not in tts.configs.languages:
|
||||
return f"text_lang: {text_lang} is not supported in version {tts.configs.version}"
|
||||
if prompt_lang in [None, ""]:
|
||||
return "prompt_lang is required"
|
||||
if prompt_lang.lower() not in tts.configs.languages:
|
||||
return f"prompt_lang: {prompt_lang} is not supported in version {tts.configs.version}"
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return f"media_type: {media_type} is not supported"
|
||||
if text_split_method not in cut_method_names:
|
||||
return f"text_split_method:{text_split_method} is not supported"
|
||||
return None
|
||||
|
||||
|
||||
def base_request_defaults() -> Dict[str, Any]:
|
||||
return {
|
||||
"request_id": None,
|
||||
"text": None,
|
||||
"text_lang": None,
|
||||
"ref_audio_path": None,
|
||||
"aux_ref_audio_paths": None,
|
||||
"prompt_text": "",
|
||||
"prompt_lang": None,
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"temperature": 1.0,
|
||||
"text_split_method": "cut5",
|
||||
"batch_size": 1,
|
||||
"batch_threshold": 0.75,
|
||||
"speed_factor": 1.0,
|
||||
"split_bucket": False,
|
||||
"fragment_interval": 0.3,
|
||||
"seed": -1,
|
||||
"media_type": "wav",
|
||||
"streaming_mode": False,
|
||||
"return_fragment": False,
|
||||
"fixed_length_chunk": False,
|
||||
"response_streaming": False,
|
||||
"parallel_infer": False,
|
||||
"repetition_penalty": 1.35,
|
||||
"sample_steps": 32,
|
||||
"super_sampling": False,
|
||||
"overlap_length": 2,
|
||||
"min_chunk_length": 16,
|
||||
"early_stop_num": -1,
|
||||
"ready_step": 0,
|
||||
"timeout_sec": None,
|
||||
}
|
||||
|
||||
|
||||
def normalize_streaming_mode(req: dict) -> dict:
|
||||
normalized = dict(req)
|
||||
streaming_mode = normalized.get("streaming_mode", False)
|
||||
return_fragment = normalized.get("return_fragment", False)
|
||||
if streaming_mode is False:
|
||||
normalized["streaming_mode"] = False
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 0:
|
||||
normalized["streaming_mode"] = False
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 1 or streaming_mode is True:
|
||||
normalized["streaming_mode"] = False
|
||||
normalized["return_fragment"] = True
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 2:
|
||||
normalized["streaming_mode"] = True
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 3:
|
||||
normalized["streaming_mode"] = True
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = True
|
||||
else:
|
||||
raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)")
|
||||
normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment)
|
||||
return normalized
|
||||
|
||||
|
||||
def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
|
||||
return aux_ref_audio_paths not in [None, [], ()]
|
||||
|
||||
|
||||
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
|
||||
if normalized.response_streaming:
|
||||
if normalized.return_fragment or normalized.fixed_length_chunk:
|
||||
return "legacy_direct_fragment", "fragment_streaming_mode"
|
||||
return "legacy_direct_streaming", "streaming_mode"
|
||||
if is_aux_ref_enabled(normalized.aux_ref_audio_paths):
|
||||
return "legacy_direct_aux_ref", "aux_ref_audio_paths"
|
||||
if normalized.super_sampling:
|
||||
return "legacy_direct_super_sampling", "super_sampling"
|
||||
if normalized.prompt_text in [None, ""]:
|
||||
return "legacy_direct_missing_prompt", "missing_prompt_text"
|
||||
return "scheduler_v1_direct", None
|
||||
|
||||
|
||||
def normalize_engine_request(
|
||||
*,
|
||||
tts: TTS,
|
||||
cut_method_names: Sequence[str],
|
||||
reference_registry: ReferenceRegistry,
|
||||
payload: dict | NormalizedEngineRequest,
|
||||
request_id: str | None = None,
|
||||
normalize_streaming: bool = False,
|
||||
error_prefix: str = "request 参数非法: ",
|
||||
) -> NormalizedEngineRequest:
|
||||
if isinstance(payload, NormalizedEngineRequest):
|
||||
normalized_payload = payload.to_payload()
|
||||
else:
|
||||
normalized_payload = base_request_defaults()
|
||||
normalized_payload.update(dict(payload))
|
||||
if request_id not in [None, ""]:
|
||||
normalized_payload["request_id"] = str(request_id)
|
||||
elif normalized_payload.get("request_id") in [None, ""]:
|
||||
raise ValueError("request_id is required after normalization")
|
||||
normalized_payload = apply_default_reference(reference_registry, normalized_payload)
|
||||
if normalize_streaming:
|
||||
normalized_payload = normalize_streaming_mode(normalized_payload)
|
||||
error = check_params(tts, cut_method_names, normalized_payload)
|
||||
if error is not None:
|
||||
raise ValueError(f"{error_prefix}{error}")
|
||||
timeout_sec = normalized_payload.get("timeout_sec")
|
||||
parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec)
|
||||
aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths")
|
||||
normalized_aux_ref_audio_paths = None if aux_ref_audio_paths in [None, "", []] else [str(item) for item in aux_ref_audio_paths]
|
||||
return NormalizedEngineRequest(
|
||||
request_id=str(normalized_payload["request_id"]),
|
||||
text=str(normalized_payload["text"]),
|
||||
text_lang=str(normalized_payload["text_lang"]),
|
||||
ref_audio_path=str(normalized_payload["ref_audio_path"]),
|
||||
prompt_lang=str(normalized_payload["prompt_lang"]),
|
||||
prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")),
|
||||
aux_ref_audio_paths=normalized_aux_ref_audio_paths,
|
||||
top_k=int(normalized_payload["top_k"]),
|
||||
top_p=float(normalized_payload["top_p"]),
|
||||
temperature=float(normalized_payload["temperature"]),
|
||||
repetition_penalty=float(normalized_payload["repetition_penalty"]),
|
||||
early_stop_num=int(normalized_payload.get("early_stop_num", -1)),
|
||||
ready_step=int(normalized_payload.get("ready_step", 0)),
|
||||
text_split_method=str(normalized_payload["text_split_method"]),
|
||||
batch_size=int(normalized_payload["batch_size"]),
|
||||
batch_threshold=float(normalized_payload["batch_threshold"]),
|
||||
split_bucket=bool(normalized_payload["split_bucket"]),
|
||||
speed_factor=float(normalized_payload["speed_factor"]),
|
||||
fragment_interval=float(normalized_payload["fragment_interval"]),
|
||||
seed=int(normalized_payload["seed"]),
|
||||
media_type=str(normalized_payload["media_type"]),
|
||||
streaming_mode=normalized_payload["streaming_mode"],
|
||||
return_fragment=bool(normalized_payload.get("return_fragment", False)),
|
||||
fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)),
|
||||
response_streaming=bool(normalized_payload.get("response_streaming", False)),
|
||||
parallel_infer=bool(normalized_payload["parallel_infer"]),
|
||||
sample_steps=int(normalized_payload["sample_steps"]),
|
||||
super_sampling=bool(normalized_payload["super_sampling"]),
|
||||
overlap_length=int(normalized_payload["overlap_length"]),
|
||||
min_chunk_length=int(normalized_payload["min_chunk_length"]),
|
||||
timeout_sec=parsed_timeout,
|
||||
)
|
||||
283
GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py
Normal file
283
GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py
Normal file
@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution
|
||||
|
||||
|
||||
class EngineApiSchedulerFlow:
|
||||
def __init__(self, api: Any) -> None:
|
||||
self.api = api
|
||||
|
||||
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, payload in enumerate(request_items):
|
||||
normalized = self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
|
||||
error_prefix=f"request[{index}] 参数非法: ",
|
||||
)
|
||||
specs.append(normalized.to_scheduler_spec())
|
||||
return specs
|
||||
|
||||
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=(
|
||||
payload.request_id
|
||||
if isinstance(payload, NormalizedEngineRequest)
|
||||
else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}")
|
||||
),
|
||||
)
|
||||
return normalized.to_scheduler_spec()
|
||||
|
||||
@staticmethod
|
||||
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
"ready_step": int(state.ready_step),
|
||||
"ref_audio_path": str(state.ref_audio_path),
|
||||
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
|
||||
"all_phone_len": int(state.all_phones.shape[0]),
|
||||
"bert_len": int(state.all_bert_features.shape[-1]),
|
||||
"norm_text": state.norm_text,
|
||||
}
|
||||
for state in states
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"request_id": item.request_id,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
request_start = time.perf_counter()
|
||||
set_scheduler_seed(seed)
|
||||
specs = self._build_scheduler_request_specs(request_items)
|
||||
request_ids = [spec.request_id for spec in specs]
|
||||
for spec in specs:
|
||||
self.api._register_request_state(
|
||||
request_id=spec.request_id,
|
||||
api_mode="scheduler_debug",
|
||||
backend="scheduler_debug",
|
||||
media_type="wav",
|
||||
response_streaming=False,
|
||||
meta={
|
||||
"text_len": len(spec.text),
|
||||
"prompt_text_len": len(spec.prompt_text),
|
||||
"text_lang": spec.text_lang,
|
||||
"prompt_lang": spec.prompt_lang,
|
||||
"ref_audio_path": str(spec.ref_audio_path),
|
||||
"ready_step": int(spec.ready_step),
|
||||
},
|
||||
)
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None)
|
||||
prepare_started_at = time.perf_counter()
|
||||
try:
|
||||
states = await self.api.scheduler_worker.prepare_states_batch_async(specs)
|
||||
except Exception as exc:
|
||||
for request_id in request_ids:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
prepare_finished_at = time.perf_counter()
|
||||
prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
|
||||
for state in states:
|
||||
self.api._update_request_state(
|
||||
state.request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
"norm_text": state.norm_text,
|
||||
"norm_prompt_text": state.norm_prompt_text,
|
||||
},
|
||||
)
|
||||
decode_started_at = time.perf_counter()
|
||||
try:
|
||||
finished = run_scheduler_continuous(self.api.tts.t2s_model.model, states, max_steps=int(max_steps))
|
||||
except Exception as exc:
|
||||
for request_id in request_ids:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
decode_finished_at = time.perf_counter()
|
||||
decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
|
||||
request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
|
||||
finished_map = {item.request_id: item for item in finished}
|
||||
request_profiles: List[Dict[str, Any]] = []
|
||||
for state in states:
|
||||
item = finished_map.get(state.request_id)
|
||||
if item is None:
|
||||
self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
|
||||
continue
|
||||
request_profile = self.api._build_scheduler_debug_request_profile(
|
||||
state=state,
|
||||
item=item,
|
||||
batch_request_count=len(states),
|
||||
prepare_batch_wall_ms=prepare_batch_wall_ms,
|
||||
decode_batch_wall_ms=decode_batch_wall_ms,
|
||||
batch_request_total_ms=request_total_ms,
|
||||
)
|
||||
request_profiles.append(
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
"profile": dict(request_profile),
|
||||
}
|
||||
)
|
||||
self.api._complete_request_state(
|
||||
state.request_id,
|
||||
dict(request_profile),
|
||||
)
|
||||
return SchedulerDebugExecution(
|
||||
payload={
|
||||
"message": "success",
|
||||
"request_count": len(states),
|
||||
"max_steps": int(max_steps),
|
||||
"batch_profile": self.api._build_scheduler_debug_batch_profile(
|
||||
request_count=len(states),
|
||||
max_steps=int(max_steps),
|
||||
prepare_batch_wall_ms=prepare_batch_wall_ms,
|
||||
decode_batch_wall_ms=decode_batch_wall_ms,
|
||||
request_total_ms=request_total_ms,
|
||||
finished_items=finished,
|
||||
),
|
||||
"requests": self._summarize_scheduler_states(states),
|
||||
"finished": self._summarize_scheduler_finished(finished),
|
||||
"request_profiles": request_profiles,
|
||||
"request_traces": self.api._collect_request_summaries(request_ids),
|
||||
}
|
||||
)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
request_start = time.perf_counter()
|
||||
prepare_start = request_start
|
||||
normalized = self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
|
||||
)
|
||||
spec = self._build_scheduler_submit_spec(normalized)
|
||||
deadline_ts = None
|
||||
timeout_sec = normalized.timeout_sec
|
||||
if timeout_sec is not None:
|
||||
try:
|
||||
deadline_ts = request_start + float(timeout_sec)
|
||||
except Exception:
|
||||
deadline_ts = None
|
||||
self.api._register_request_state(
|
||||
request_id=spec.request_id,
|
||||
api_mode="scheduler_submit",
|
||||
backend="scheduler_v1",
|
||||
media_type=normalized.media_type,
|
||||
response_streaming=False,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"})
|
||||
spec_ready_at = time.perf_counter()
|
||||
prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms})
|
||||
try:
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=spec_ready_at,
|
||||
engine_request_id=spec.request_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(spec.request_id, str(exc))
|
||||
raise
|
||||
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0)
|
||||
prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)
|
||||
prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
|
||||
prepare_profile = dict(state.prepare_profile)
|
||||
prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms)
|
||||
self.api._update_request_state(
|
||||
spec.request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile": prepare_profile,
|
||||
},
|
||||
)
|
||||
api_after_prepare_start = time.perf_counter()
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=spec.request_id,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0)
|
||||
try:
|
||||
job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0))
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(spec.request_id, str(exc))
|
||||
raise
|
||||
wait_return_at = time.perf_counter()
|
||||
if job.error is not None:
|
||||
raise RuntimeError(job.error)
|
||||
if job.audio_data is None or job.sample_rate is None or job.result is None:
|
||||
self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result")
|
||||
raise RuntimeError(f"{job.request_id} finished without audio result")
|
||||
pack_start = time.perf_counter()
|
||||
audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
|
||||
pack_end = time.perf_counter()
|
||||
pack_ms = (pack_end - pack_start) * 1000.0
|
||||
api_wait_result_ms = 0.0
|
||||
if job.result_ready_time is not None:
|
||||
api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
|
||||
response_ready_at = time.perf_counter()
|
||||
response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
|
||||
submit_profile = self.api._build_scheduler_submit_profile(
|
||||
backend="scheduler_v1",
|
||||
request_start=request_start,
|
||||
response_ready_at=response_ready_at,
|
||||
audio_bytes=len(audio_data),
|
||||
sample_rate=int(job.sample_rate),
|
||||
prepare_spec_build_ms=prepare_spec_build_ms,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_executor_queue_ms=prepare_executor_queue_ms,
|
||||
prepare_executor_run_ms=prepare_executor_run_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
prepare_profile_wall_ms=prepare_profile_wall_ms,
|
||||
prepare_other_ms=prepare_other_ms,
|
||||
engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)),
|
||||
api_after_prepare_ms=api_after_prepare_ms,
|
||||
api_wait_result_ms=api_wait_result_ms,
|
||||
pack_ms=pack_ms,
|
||||
response_overhead_ms=response_overhead_ms,
|
||||
worker_profile=dict(job.result or {}),
|
||||
)
|
||||
headers = self.api._build_scheduler_submit_headers(
|
||||
request_id=job.request_id,
|
||||
media_type=job.media_type,
|
||||
sample_rate=int(job.sample_rate),
|
||||
profile=submit_profile,
|
||||
)
|
||||
self.api._merge_request_state_profile(
|
||||
spec.request_id,
|
||||
dict(submit_profile, response_headers_emitted=True),
|
||||
)
|
||||
return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers)
|
||||
Loading…
x
Reference in New Issue
Block a user