GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py
baicai-1145 6a822b28c3 Enhance TTS API with improved request handling and asynchronous processing
Refactor api_v2.py and api_v3.py to update sampling parameters and weight paths for better clarity and support for v3/v4 vocoders. Introduce new methods in PrepareCoordinator for handling empty text features and improve profiling capabilities. Additionally, update unified engine components to streamline audio processing and state management, enhancing overall performance and maintainability of the TTS system.
2026-03-12 01:27:19 +08:00

596 lines
26 KiB
Python

from __future__ import annotations
import asyncio
import queue
import threading
import time
import uuid
from io import BytesIO
from typing import Any, Dict, Generator, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, wave_header_chunk
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineStatus, NormalizedEngineRequest, SchedulerPendingJob
class EngineApiDirectFlow:
    """Direct TTS request flow for the unified engine.

    Routes a direct TTS request either to the scheduler backend (one scheduler
    job per text segment) or to the legacy ``api.tts.run`` pipeline, and keeps
    the request-state bookkeeping on ``api`` (register / update / fail /
    complete) in step with what actually happens.
    """

    # Backend labels recorded in request state.
    _SCHEDULER_BACKEND = "scheduler_v1_direct"
    _SYNC_COMPAT_BACKEND = "legacy_direct_sync_compat"
    _SYNC_COMPAT_REASON = "sync_direct_compat"
    # Per-segment wait (seconds) used when the request carries no explicit timeout.
    _DEFAULT_SEGMENT_TIMEOUT_SEC = 30.0
    # Interval (seconds) at which a blocked stream producer re-checks whether the consumer is gone.
    _STREAM_PUT_POLL_SEC = 0.1
    # Failure reason recorded when a streaming consumer stops reading before the end.
    _CONSUMER_CLOSED_MESSAGE = "stream consumer closed before completion"

    def __init__(self, api: Any) -> None:
        # Owning engine API: provides ``tts``, ``direct_tts_lock`` and the request-state helpers.
        self.api = api

    # ------------------------------------------------------------------ small helpers

    @staticmethod
    def _direct_request_id(req: dict) -> str:
        """Return the caller's ``request_id`` (as str) or a fresh ``direct_<12 hex>`` id."""
        return str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}")

    @staticmethod
    def _backend_fields(backend: str, **extra: Any) -> Dict[str, Any]:
        """Return a new request-state payload tagging ``backend`` (also as ``backend_mode``)."""
        return {"backend": backend, "backend_mode": backend, **extra}

    @staticmethod
    def _validated_fields(backend: str, fallback_reason: str | None) -> Dict[str, Any]:
        """Return the payload recorded with the VALIDATED transition of a direct request."""
        return {
            "request_source": "direct_tts",
            "selected_backend": backend,
            "fallback_reason": fallback_reason,
        }

    @staticmethod
    def _close_iterator(iterator: Any) -> None:
        """Close a generator-like iterator so its cleanup runs now instead of at GC time."""
        close = getattr(iterator, "close", None)
        if callable(close):
            close()

    # ------------------------------------------------------------------ legacy backend

    def _iter_legacy_direct_tts_bytes(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> Generator[bytes, None, None]:
        """Stream encoded audio from the legacy ``api.tts.run`` pipeline.

        The whole synthesis runs while ``api.direct_tts_lock`` is held. For
        ``wav`` output a single WAV header is emitted first and the remaining
        chunks use the ``raw`` packing; other media types pack every chunk with
        the requested media type.

        Request state goes ACTIVE_DECODE -> STREAMING (on the first chunk) ->
        completed; it is marked failed if the pipeline raises or if the consumer
        closes this generator before the end.
        """
        payload = normalized.to_payload()
        media_type = normalized.media_type
        request_id = normalized.request_id
        request_start = time.perf_counter()
        chunk_count = 0
        stream_total_bytes = 0
        first_chunk_ms: float | None = None
        self.api._update_request_state(
            request_id,
            EngineStatus.ACTIVE_DECODE,
            self._backend_fields(backend, fallback_reason=fallback_reason),
        )
        try:
            with self.api.direct_tts_lock:
                tts_generator = self.api.tts.run(payload)
                try:
                    chunk_media_type = media_type
                    for sr, chunk in tts_generator:
                        if first_chunk_ms is None:
                            first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
                            self.api._update_request_state(
                                request_id,
                                EngineStatus.STREAMING,
                                self._backend_fields(
                                    backend,
                                    fallback_reason=fallback_reason,
                                    sample_rate=int(sr),
                                ),
                            )
                            if media_type == "wav":
                                header = wave_header_chunk(sample_rate=sr)
                                chunk_count += 1
                                stream_total_bytes += len(header)
                                yield header
                                # Only one header per stream; later chunks use the raw packing.
                                chunk_media_type = "raw"
                        packed_chunk = pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue()
                        chunk_count += 1
                        stream_total_bytes += len(packed_chunk)
                        yield packed_chunk
                finally:
                    # Run the pipeline's own cleanup while the lock is still held.
                    self._close_iterator(tts_generator)
        except GeneratorExit:
            # The consumer stopped early; don't leave the request stuck in STREAMING.
            self.api._fail_request_state(request_id, self._CONSUMER_CLOSED_MESSAGE)
            raise
        except Exception as exc:
            self.api._fail_request_state(request_id, str(exc))
            raise
        self.api._complete_request_state(
            request_id,
            dict(
                self.api._build_legacy_direct_profile(
                    backend=backend,
                    fallback_reason=fallback_reason,
                    request_start=request_start,
                    finished_at=time.perf_counter(),
                    audio_bytes=stream_total_bytes,
                    chunk_count=chunk_count,
                    stream_total_bytes=stream_total_bytes,
                    first_chunk_ms=first_chunk_ms,
                ),
                streaming_completed=True,
            ),
        )

    def _legacy_streaming_execution(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> DirectTTSExecution:
        """Wrap the legacy streaming generator in a streaming ``DirectTTSExecution``."""
        return DirectTTSExecution(
            media_type=normalized.media_type,
            streaming=True,
            audio_generator=self._iter_legacy_direct_tts_bytes(
                normalized,
                backend=backend,
                fallback_reason=fallback_reason,
            ),
            request_id=normalized.request_id,
        )

    def _run_legacy_direct_tts_blocking(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> DirectTTSExecution:
        """Synthesize the whole request with the legacy pipeline and return packed bytes.

        Blocking; the async entry point runs it through ``asyncio.to_thread``.
        Only the first item yielded by ``api.tts.run`` is used (presumably the
        full utterance in non-streaming mode -- TODO confirm).

        Raises:
            RuntimeError: If the pipeline yields nothing.
            Exception: Anything raised by the pipeline or by ``pack_audio``; the
                request state is marked failed before re-raising.
        """
        normalized_payload = normalized.to_payload()
        request_id = normalized.request_id
        media_type = normalized.media_type
        request_start = time.perf_counter()
        self.api._update_request_state(
            request_id,
            EngineStatus.ACTIVE_DECODE,
            self._backend_fields(backend, fallback_reason=fallback_reason),
        )
        try:
            # Pull the result *inside* the lock: when ``tts.run`` is a generator
            # function, calling it does no work -- synthesis happens in ``next``.
            with self.api.direct_tts_lock:
                tts_generator = self.api.tts.run(normalized_payload)
                try:
                    first_item = next(tts_generator, None)
                finally:
                    self._close_iterator(tts_generator)
            if first_item is None:
                # An explicit error instead of a bare StopIteration: that one has no
                # message and does not propagate cleanly through asyncio futures.
                raise RuntimeError("legacy TTS pipeline produced no audio")
            sr, audio_data = first_item
            self.api._update_request_state(
                request_id,
                EngineStatus.FINALIZING,
                self._backend_fields(backend, fallback_reason=fallback_reason),
            )
            pack_start = time.perf_counter()
            packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
            pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
        except Exception as exc:
            self.api._fail_request_state(request_id, str(exc))
            raise
        self.api._complete_request_state(
            request_id,
            dict(
                self.api._build_legacy_direct_profile(
                    backend=backend,
                    fallback_reason=fallback_reason,
                    request_start=request_start,
                    finished_at=time.perf_counter(),
                    sample_rate=int(sr),
                    audio_bytes=len(packed_audio),
                    pack_ms=pack_ms,
                ),
                streaming_completed=False,
            ),
        )
        return DirectTTSExecution(
            media_type=media_type,
            streaming=False,
            audio_bytes=packed_audio,
            request_id=request_id,
        )

    async def _run_direct_tts_via_legacy_backend(
        self,
        normalized: NormalizedEngineRequest,
        *,
        backend: str,
        fallback_reason: str | None,
    ) -> DirectTTSExecution:
        """Run a request on the legacy pipeline: lazily when streaming, else in a worker thread."""
        if normalized.response_streaming:
            return self._legacy_streaming_execution(
                normalized,
                backend=backend,
                fallback_reason=fallback_reason,
            )
        return await asyncio.to_thread(
            self._run_legacy_direct_tts_blocking,
            normalized,
            backend=backend,
            fallback_reason=fallback_reason,
        )

    # ------------------------------------------------------------------ scheduler backend

    def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
        """Return True when the backend selector routes ``req`` to the scheduler backend.

        A raw dict is normalized first, using its ``request_id`` or a generated
        ``direct_...`` id.
        """
        if isinstance(req, NormalizedEngineRequest):
            normalized = req
        else:
            normalized = self.api._normalize_engine_request(
                req,
                request_id=self._direct_request_id(req),
                normalize_streaming=True,
            )
        backend, _ = self.api._select_direct_backend(normalized)
        return backend == self._SCHEDULER_BACKEND

    def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
        """Split the request text with ``tts.text_preprocessor.pre_seg_text``.

        The split method defaults to ``"cut5"`` when the payload has no
        ``text_split_method`` key.
        """
        payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized
        return self.api.tts.text_preprocessor.pre_seg_text(
            str(payload["text"]),
            str(payload["text_lang"]),
            str(payload.get("text_split_method", "cut5")),
        )

    def _build_segment_request(
        self,
        normalized: NormalizedEngineRequest,
        *,
        request_id: str,
        text: str,
    ) -> NormalizedEngineRequest:
        """Derive a per-segment request: same settings, new id/text, every streaming flag off."""
        payload = normalized.to_payload()
        payload["request_id"] = request_id
        payload["text"] = text
        payload["streaming_mode"] = False
        payload["return_fragment"] = False
        payload["fixed_length_chunk"] = False
        payload["response_streaming"] = False
        return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ")

    async def _execute_single_segment_scheduler_job(
        self,
        normalized: NormalizedEngineRequest,
        *,
        segment_request: NormalizedEngineRequest,
    ) -> tuple[SchedulerPendingJob, Dict[str, Any]]:
        """Prepare one segment on the engine GPU queue, enqueue it and wait for its job.

        Args:
            normalized: Parent request (speed / sample steps / media / timeout settings).
            segment_request: The per-segment request built by ``_build_segment_request``.

        Returns:
            ``(job, profile)`` where ``profile`` holds the segment id and prepare timings.

        Raises:
            asyncio.TimeoutError: If the job does not finish within the request's
                ``timeout_sec`` (default ``_DEFAULT_SEGMENT_TIMEOUT_SEC``).
        """
        spec = self.api._build_scheduler_submit_spec(segment_request)
        state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
            spec=spec,
            prepare_submit_at=time.perf_counter(),
            engine_request_id=None,
        )
        prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
        prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
        loop = asyncio.get_running_loop()
        done_future = loop.create_future()
        await self.api._enqueue_prepared_state_for_dispatch(
            state=state,
            speed_factor=float(normalized.speed_factor),
            sample_steps=int(normalized.sample_steps),
            media_type=normalized.media_type,
            super_sampling=bool(normalized.super_sampling),
            prepare_wall_ms=prepare_wall_ms,
            prepare_profile_total_ms=prepare_profile_total_ms,
            done_loop=loop,
            done_future=done_future,
            engine_request_id=None,
            timeout_sec=normalized.timeout_sec,
        )
        timeout_sec = float(
            normalized.timeout_sec if normalized.timeout_sec is not None else self._DEFAULT_SEGMENT_TIMEOUT_SEC
        )
        try:
            job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec)
        except asyncio.TimeoutError as exc:
            # The bare timeout carries no message; name the segment so request state is useful.
            raise asyncio.TimeoutError(
                f"scheduler segment {spec.request_id} did not finish within {timeout_sec:g}s"
            ) from exc
        return job, {
            "request_id": spec.request_id,
            "prepare_wall_ms": prepare_wall_ms,
            "prepare_profile_total_ms": prepare_profile_total_ms,
            "prepare_profile": dict(state.prepare_profile),
        }

    def _iter_scheduler_direct_tts_bytes(
        self,
        normalized: NormalizedEngineRequest,
        *,
        segment_texts: Optional[List[str]] = None,
    ) -> Generator[bytes, None, None]:
        """Stream encoded audio produced segment-by-segment by the scheduler backend.

        A producer thread runs its own event loop (``asyncio.run``), submits one
        scheduler job per text segment in order, and pushes packed chunks into a
        bounded queue that this generator drains. ``fragment_interval`` seconds
        of int16 silence are inserted between segments (not after the last).
        For ``wav`` output a single WAV header precedes ``raw``-packed chunks.

        If the consumer stops iterating early, the producer notices, stops
        submitting segments, marks the request failed and exits instead of
        blocking on the full queue forever.

        Args:
            normalized: The validated direct request.
            segment_texts: Pre-computed text segments; segmented here when omitted.

        Raises:
            ValueError: If segmentation yields no segments.
            Exception: Whatever the producer raised, re-raised in the consumer.
        """
        request_start = time.perf_counter()
        request_id = normalized.request_id
        media_type = normalized.media_type
        if segment_texts is None:
            segment_texts = self._segment_direct_text(normalized)
        if not segment_texts:
            raise ValueError("text preprocessing returned no valid segments")
        segment_count = len(segment_texts)
        chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8)
        done_marker = object()
        # Set once this generator stops consuming (finished, raised, or closed early).
        consumer_closed = threading.Event()

        def _offer(item: object) -> bool:
            """Put ``item`` on the queue; return False (dropping it) once the consumer is gone."""
            while not consumer_closed.is_set():
                try:
                    chunk_queue.put(item, timeout=self._STREAM_PUT_POLL_SEC)
                except queue.Full:
                    continue
                return True
            return False

        async def _produce_chunks() -> None:
            sample_rate: int | None = None
            current_media_type = media_type
            chunk_count = 0
            stream_total_bytes = 0
            first_chunk_ms: float | None = None
            prepare_profiles: List[Dict[str, Any]] = []
            worker_profiles: List[Dict[str, Any]] = []

            def _emit(data: bytes) -> None:
                # Account for the chunk and hand it over; abort if nobody is reading.
                nonlocal chunk_count, stream_total_bytes
                chunk_count += 1
                stream_total_bytes += len(data)
                if not _offer(data):
                    raise RuntimeError(self._CONSUMER_CLOSED_MESSAGE)

            try:
                # Inside the try so that any failure still reaches the consumer.
                self.api._update_request_state(
                    request_id,
                    EngineStatus.CPU_PREPARING,
                    self._backend_fields(self._SCHEDULER_BACKEND, segment_count=segment_count),
                )
                for segment_index, segment_text in enumerate(segment_texts):
                    if consumer_closed.is_set():
                        # Don't spend scheduler time on audio nobody will read.
                        raise RuntimeError(self._CONSUMER_CLOSED_MESSAGE)
                    segment_request = self._build_segment_request(
                        normalized,
                        request_id=f"{request_id}_seg_{segment_index:03d}",
                        text=segment_text,
                    )
                    self.api._update_request_state(
                        request_id,
                        EngineStatus.READY_FOR_PREFILL,
                        self._backend_fields(
                            self._SCHEDULER_BACKEND,
                            segment_index=segment_index,
                            segment_count=segment_count,
                        ),
                    )
                    job, prepare_profile = await self._execute_single_segment_scheduler_job(
                        normalized,
                        segment_request=segment_request,
                    )
                    prepare_profiles.append(prepare_profile)
                    if job.error is not None:
                        raise RuntimeError(job.error)
                    if job.audio_data is None or job.sample_rate is None or job.result is None:
                        raise RuntimeError(f"{job.request_id} finished without audio result")
                    worker_profiles.append(dict(job.result))
                    job_sample_rate = int(job.sample_rate)
                    if sample_rate is None:
                        sample_rate = job_sample_rate
                        first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
                        self.api._update_request_state(
                            request_id,
                            EngineStatus.STREAMING,
                            self._backend_fields(self._SCHEDULER_BACKEND, sample_rate=int(sample_rate)),
                        )
                        if media_type == "wav":
                            _emit(wave_header_chunk(sample_rate=int(sample_rate)))
                            current_media_type = "raw"
                    _emit(pack_audio(BytesIO(), job.audio_data, job_sample_rate, current_media_type).getvalue())
                    if segment_index + 1 < segment_count:
                        silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate))
                        if silence_samples > 0:
                            silence_chunk = np.zeros(silence_samples, dtype=np.int16)
                            _emit(
                                pack_audio(BytesIO(), silence_chunk, job_sample_rate, current_media_type).getvalue()
                            )
            except Exception as exc:
                self.api._fail_request_state(request_id, str(exc))
                _offer(exc)
            else:
                self.api._merge_request_state_profile(
                    request_id,
                    {
                        "prepare_aggregate": self.api._aggregate_numeric_dicts(
                            [item["prepare_profile"] for item in prepare_profiles]
                        ),
                        "engine_policy_wait_ms": sum(
                            float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles
                        ),
                        "engine_dispatch_wait_ms": sum(
                            float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles
                        ),
                    },
                )
                direct_profile = self.api._build_direct_scheduler_profile(
                    backend=self._SCHEDULER_BACKEND,
                    request_start=request_start,
                    response_ready_at=time.perf_counter(),
                    audio_bytes=stream_total_bytes,
                    sample_rate=int(sample_rate or 0),
                    segment_texts=segment_texts,
                    prepare_profiles=prepare_profiles,
                    worker_profiles=worker_profiles,
                    pack_ms=0.0,
                    response_overhead_ms=0.0,
                )
                self.api._complete_request_state(
                    request_id,
                    dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms),
                )
            finally:
                _offer(done_marker)

        producer_thread = threading.Thread(
            target=lambda: asyncio.run(_produce_chunks()),
            name=f"direct-tts-stream-{request_id}",
            daemon=True,
        )
        producer_thread.start()
        try:
            while True:
                item = chunk_queue.get()
                if item is done_marker:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # Releases a producer blocked on a full queue if we stop consuming early.
            consumer_closed.set()

    async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
        """Run a direct request on the scheduler backend.

        Streaming requests return a lazy generator (segments decoded in order).
        Otherwise all segments are submitted concurrently, their audio is joined
        with ``fragment_interval`` seconds of int16 silence, and the result is
        packed once with the requested media type.

        Raises:
            ValueError: If segmentation yields no segments.
            RuntimeError: If a job reports an error, returns no audio, or
                segments disagree on sample rate.
        """
        request_start = time.perf_counter()
        request_id = normalized.request_id
        media_type = normalized.media_type
        segment_texts = self._segment_direct_text(normalized)
        if not segment_texts:
            raise ValueError("text preprocessing returned no valid segments")
        if normalized.response_streaming:
            return DirectTTSExecution(
                media_type=media_type,
                streaming=True,
                # Reuse the segmentation computed above instead of redoing it in the generator.
                audio_generator=self._iter_scheduler_direct_tts_bytes(normalized, segment_texts=segment_texts),
                request_id=request_id,
            )
        self.api._update_request_state(
            request_id,
            EngineStatus.CPU_PREPARING,
            self._backend_fields(self._SCHEDULER_BACKEND, segment_count=len(segment_texts)),
        )
        segment_requests = [
            self._build_segment_request(
                normalized,
                request_id=f"{request_id}_seg_{segment_index:03d}",
                text=segment_text,
            )
            for segment_index, segment_text in enumerate(segment_texts)
        ]
        self.api._update_request_state(
            request_id,
            EngineStatus.READY_FOR_PREFILL,
            self._backend_fields(self._SCHEDULER_BACKEND, segment_count=len(segment_requests)),
        )
        # Each call already waits (with its own timeout) for its finished job.
        # NOTE(review): if one segment raises, gather does not cancel the others.
        prepared_items = await asyncio.gather(
            *[
                self._execute_single_segment_scheduler_job(
                    normalized,
                    segment_request=segment_request,
                )
                for segment_request in segment_requests
            ]
        )
        jobs: List[SchedulerPendingJob] = [job for job, _ in prepared_items]
        prepare_profiles: List[Dict[str, Any]] = [profile for _, profile in prepared_items]
        self.api._update_request_state(
            request_id,
            EngineStatus.ACTIVE_DECODE,
            self._backend_fields(self._SCHEDULER_BACKEND),
        )
        for profile_item, job in zip(prepare_profiles, jobs):
            profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms)
            profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms)
        self.api._merge_request_state_profile(
            request_id,
            {
                "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs),
                "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs),
                "prepare_aggregate": self.api._aggregate_numeric_dicts(
                    [item["prepare_profile"] for item in prepare_profiles]
                ),
            },
        )
        sample_rate: int | None = None
        audio_parts: List[np.ndarray] = []
        worker_profiles: List[Dict[str, Any]] = []
        fragment_interval = float(normalized.fragment_interval)
        silence_chunk: Optional[np.ndarray] = None
        for job in jobs:
            if job.error is not None:
                raise RuntimeError(job.error)
            if job.audio_data is None or job.sample_rate is None or job.result is None:
                raise RuntimeError(f"{job.request_id} finished without audio result")
            if sample_rate is None:
                sample_rate = int(job.sample_rate)
                silence_samples = int(fragment_interval * float(sample_rate))
                if silence_samples > 0:
                    silence_chunk = np.zeros(silence_samples, dtype=np.int16)
            elif int(job.sample_rate) != sample_rate:
                raise RuntimeError("segment sample rate mismatch")
            audio_parts.append(job.audio_data)
            # NOTE(review): unlike the streaming path, silence also follows the last
            # segment here -- confirm which behaviour is intended.
            if silence_chunk is not None:
                audio_parts.append(silence_chunk)  # read-only; concatenate copies anyway
            worker_profiles.append(dict(job.result))
        if sample_rate is None or not audio_parts:
            raise RuntimeError("direct scheduler backend produced no audio")
        self.api._update_request_state(
            request_id,
            EngineStatus.FINALIZING,
            self._backend_fields(self._SCHEDULER_BACKEND),
        )
        merged_audio = np.concatenate(audio_parts, axis=0)
        pack_start = time.perf_counter()
        audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue()
        pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
        direct_profile = self.api._build_direct_scheduler_profile(
            backend=self._SCHEDULER_BACKEND,
            request_start=request_start,
            response_ready_at=time.perf_counter(),
            audio_bytes=len(audio_bytes),
            sample_rate=int(sample_rate),
            segment_texts=segment_texts,
            prepare_profiles=prepare_profiles,
            worker_profiles=worker_profiles,
            pack_ms=pack_ms,
            response_overhead_ms=0.0,
        )
        self.api._complete_request_state(
            request_id,
            dict(direct_profile, streaming_completed=False),
        )
        return DirectTTSExecution(
            media_type=media_type,
            streaming=False,
            audio_bytes=audio_bytes,
            request_id=request_id,
        )

    # ------------------------------------------------------------------ public entry points

    async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
        """Normalize, register and run a direct TTS request on the selected backend.

        Registers request state with a deadline derived from ``timeout_sec``.
        Scheduler-backend failures are recorded here; legacy backends record
        their own failures.
        """
        normalized = self.api._normalize_engine_request(
            req,
            request_id=self._direct_request_id(req),
            normalize_streaming=True,
            error_prefix="",
        )
        request_id = normalized.request_id
        media_type = normalized.media_type
        backend, fallback_reason = self.api._select_direct_backend(normalized)
        self.api._register_request_state(
            request_id=request_id,
            api_mode="tts",
            backend=backend,
            media_type=media_type,
            response_streaming=bool(normalized.response_streaming),
            deadline_ts=(time.perf_counter() + float(normalized.timeout_sec) if normalized.timeout_sec is not None else None),
            meta=self.api._build_request_meta(normalized.to_payload()),
        )
        self.api._update_request_state(
            request_id,
            EngineStatus.VALIDATED,
            self._validated_fields(backend, fallback_reason),
        )
        if backend == self._SCHEDULER_BACKEND:
            try:
                return await self._run_direct_tts_via_scheduler(normalized)
            except Exception as exc:
                self.api._fail_request_state(request_id, str(exc))
                raise
        return await self._run_direct_tts_via_legacy_backend(
            normalized,
            backend=backend,
            fallback_reason=fallback_reason,
        )

    def run_direct_tts(self, req: dict) -> DirectTTSExecution:
        """Synchronous entry point: always executes on the legacy pipeline.

        The scheduler path (``_run_direct_tts_via_scheduler``) is a coroutine,
        so when the selector picks the scheduler backend this method runs the
        legacy pipeline instead, labelled ``legacy_direct_sync_compat``.
        Request state is registered only if no active entry exists for the id.
        """
        normalized = self.api._normalize_engine_request(
            req,
            request_id=self._direct_request_id(req),
            normalize_streaming=True,
            error_prefix="",
        )
        request_id = normalized.request_id
        media_type = normalized.media_type
        backend, fallback_reason = self.api._select_direct_backend(normalized)
        if not self.api._has_active_request(request_id):
            # NOTE(review): unlike run_direct_tts_async, no deadline_ts is registered
            # here -- confirm that is intentional.
            self.api._register_request_state(
                request_id=request_id,
                api_mode="tts",
                backend=backend,
                media_type=media_type,
                response_streaming=bool(normalized.response_streaming),
                meta=self.api._build_request_meta(normalized.to_payload()),
            )
        self.api._update_request_state(
            request_id,
            EngineStatus.VALIDATED,
            self._validated_fields(backend, fallback_reason),
        )
        if backend == self._SCHEDULER_BACKEND:
            run_backend, run_reason = self._SYNC_COMPAT_BACKEND, self._SYNC_COMPAT_REASON
        else:
            run_backend, run_reason = backend, fallback_reason
        if normalized.response_streaming:
            return self._legacy_streaming_execution(
                normalized,
                backend=run_backend,
                fallback_reason=run_reason,
            )
        return self._run_legacy_direct_tts_blocking(
            normalized,
            backend=run_backend,
            fallback_reason=run_reason,
        )