mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2026-05-15 06:28:11 +08:00)
Refactor api_v2.py and api_v3.py to update sampling parameters and weight paths for better clarity and support for v3/v4 vocoders. Introduce new methods in PrepareCoordinator for handling empty text features and improve profiling capabilities. Additionally, update unified engine components to streamline audio processing and state management, enhancing overall performance and maintainability of the TTS system.
341 lines
17 KiB
Python
from __future__ import annotations

import asyncio
import time
import uuid
from io import BytesIO
from typing import Any, Dict, List

from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution


class EngineApiSchedulerFlow:
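    """Scheduler-backed request flows for the engine API.

    A thin orchestration layer over a host ``api`` object: all of the
    ``self.api._*`` calls below (request normalization, request-state
    tracking, GPU-queue preparation, dispatch, profiling) are hooks provided
    by that host. ``run_scheduler_debug`` drives a batch of requests through
    prepare and decode with detailed profiling; ``run_scheduler_submit``
    runs a single request end-to-end and returns packed audio.
    """
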
    def __init__(self, api: Any) -> None:
        self.api = api

    def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
        specs: List[SchedulerRequestSpec] = []
        for index, payload in enumerate(request_items):
            normalized = self.api._normalize_engine_request(
                payload,
                request_id=str(payload.get("request_id") or f"req_{index:03d}"),
                error_prefix=f"request[{index}] invalid parameters: ",
            )
            specs.append(normalized.to_scheduler_spec())
        return specs

    def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
        normalized = self.api._normalize_engine_request(
            payload,
            request_id=(
                payload.request_id
                if isinstance(payload, NormalizedEngineRequest)
                else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}")
            ),
        )
        return normalized.to_scheduler_spec()

    @staticmethod
    def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
        return [
            {
                "request_id": state.request_id,
                "ready_step": int(state.ready_step),
                "ref_audio_path": str(state.ref_audio_path),
                "prompt_semantic_len": int(state.prompt_semantic.shape[0]),
                "all_phone_len": int(state.all_phones.shape[0]),
                "bert_len": int(state.all_bert_features.shape[-1]),
                "norm_text": state.norm_text,
            }
            for state in states
        ]

    @staticmethod
    def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
        return [
            {
                "request_id": item.request_id,
                "semantic_len": int(item.semantic_tokens.shape[0]),
                "finish_idx": int(item.finish_idx),
                "finish_reason": item.finish_reason,
            }
            for item in items
        ]

    async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
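        """Run a batch of requests through the scheduler with profiling.

        Normalizes every payload, temporarily clamps the scheduler worker to
        ``max_steps``, prepares all states through the engine GPU queue, then
        decodes them and aggregates per-request and batch-level profiles.
        """
        # Phase 1: normalize all payloads up front so an invalid request
        # fails the batch before any GPU work is scheduled.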
        request_start = time.perf_counter()
        set_scheduler_seed(seed)
        normalized_requests: List[NormalizedEngineRequest] = []
        for index, payload in enumerate(request_items):
            normalized_requests.append(
                self.api._normalize_engine_request(
                    payload,
                    request_id=str(payload.get("request_id") or f"req_{index:03d}"),
                    error_prefix=f"request[{index}] invalid parameters: ",
                )
            )
        specs = [normalized.to_scheduler_spec() for normalized in normalized_requests]
        request_ids = [normalized.request_id for normalized in normalized_requests]
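        # Phase 2: register tracing state for every request before preparation.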
        for normalized, spec in zip(normalized_requests, specs):
            self.api._register_request_state(
                request_id=normalized.request_id,
                api_mode="scheduler_debug",
                backend="scheduler_debug",
                media_type=normalized.media_type,
                response_streaming=False,
                meta=self.api._build_request_meta(normalized.to_payload()),
            )
            self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
            self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None)
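        # Phase 3: prepare all states concurrently, clamping the worker's
        # max_steps for the duration and restoring the originals afterwards.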
        prepare_started_at = time.perf_counter()
        original_worker_max_steps = int(self.api.scheduler_worker.max_steps)
        original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps)
        try:
            self.api.scheduler_worker.max_steps = int(max_steps)
            self.api.scheduler_worker.decode_executor.max_steps = int(max_steps)
            prepared_payloads = await asyncio.gather(
                *[
                    self.api._prepare_state_via_engine_gpu_queue(
                        spec=spec,
                        prepare_submit_at=time.perf_counter(),
                        engine_request_id=normalized.request_id,
                    )
                    for normalized, spec in zip(normalized_requests, specs)
                ]
            )
        except Exception as exc:
            for request_id in request_ids:
                self.api._fail_request_state(request_id, str(exc))
            raise
        finally:
            self.api.scheduler_worker.max_steps = int(original_worker_max_steps)
            self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps)
        prepare_finished_at = time.perf_counter()
        prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
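        # _prepare_state_via_engine_gpu_queue returns a (state, started_at,
        # finished_at) tuple; only the prepared state is needed here.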
        states = [payload[0] for payload in prepared_payloads]
        for state in states:
            self.api._update_request_state(
                state.request_id,
                EngineStatus.READY_FOR_PREFILL,
                {
                    "prepare_profile": dict(state.prepare_profile),
                    "norm_text": state.norm_text,
                    "norm_prompt_text": state.norm_prompt_text,
                },
            )
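        # Phase 4: enqueue every prepared state for decoding and wait for all
        # completion futures, bounded by the largest per-request timeout.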
        decode_started_at = time.perf_counter()
        try:
            loop = asyncio.get_running_loop()
            done_futures: List[asyncio.Future] = []
            for normalized, state in zip(normalized_requests, states):
                done_future = loop.create_future()
                done_futures.append(done_future)
                await self.api._enqueue_prepared_state_for_dispatch(
                    state=state,
                    speed_factor=float(normalized.speed_factor),
                    sample_steps=int(normalized.sample_steps),
                    media_type=normalized.media_type,
                    super_sampling=bool(normalized.super_sampling),
                    prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
                    prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
                    done_loop=loop,
                    done_future=done_future,
                    engine_request_id=normalized.request_id,
                    timeout_sec=normalized.timeout_sec,
                )
            timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]]
            timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0
            jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec)))
        except Exception as exc:
            for request_id in request_ids:
                self.api._fail_request_state(request_id, str(exc))
            raise
        decode_finished_at = time.perf_counter()
        decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
        request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
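        # Phase 5: fold each job result into a per-request profile and
        # aggregate batch-level statistics.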
        request_profiles: List[Dict[str, Any]] = []
        finished: List[Dict[str, Any]] = []
        finish_reason_counts: Dict[str, int] = {}
        total_semantic_len = 0
        for state, job in zip(states, jobs):
            if job.error is not None:
                self.api._fail_request_state(state.request_id, str(job.error))
                raise RuntimeError(str(job.error))
            if job.result is None:
                self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
                raise RuntimeError(f"{state.request_id} finished without result")
            job_result = dict(job.result)
            request_profile = {
                **job_result,
                "backend": "scheduler_debug",
                "backend_mode": "scheduler_debug",
                "batch_request_count": int(len(states)),
                "batch_prepare_wall_ms": float(prepare_batch_wall_ms),
                "batch_decode_wall_ms": float(decode_batch_wall_ms),
                "batch_request_total_ms": float(request_total_ms),
                "prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
                "prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
                "prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
                "prepare_profile": dict(state.prepare_profile),
                "norm_text": state.norm_text,
                "norm_prompt_text": state.norm_prompt_text,
            }
            request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)})
            self.api._merge_request_state_profile(state.request_id, request_profile)
            semantic_len = int(job_result.get("semantic_len", 0))
            finish_reason = str(job_result.get("finish_reason", "unknown"))
            finished.append(
                {
                    "request_id": state.request_id,
                    "semantic_len": semantic_len,
                    "finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))),
                    "finish_reason": finish_reason,
                }
            )
            finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1
            total_semantic_len += semantic_len
        return SchedulerDebugExecution(
            payload={
                "message": "success",
                "request_count": len(states),
                "max_steps": int(max_steps),
                "batch_profile": {
                    "request_count": int(len(states)),
                    "max_steps": int(max_steps),
                    "prepare_batch_wall_ms": float(prepare_batch_wall_ms),
                    "decode_batch_wall_ms": float(decode_batch_wall_ms),
                    "request_total_ms": float(request_total_ms),
                    "total_semantic_len": int(total_semantic_len),
                    "finish_reason_counts": finish_reason_counts,
                },
                "requests": self._summarize_scheduler_states(states),
                "finished": finished,
                "request_profiles": request_profiles,
                "request_traces": self.api._collect_request_summaries(request_ids),
            }
        )

    async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
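        """Run a single request end-to-end and return the packed audio.

        Mirrors the debug flow for one request: normalize, register state,
        prepare via the engine GPU queue, dispatch for decoding, then pack
        the resulting audio and attach a detailed timing profile.
        """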
        request_start = time.perf_counter()
        prepare_start = request_start
        normalized = self.api._normalize_engine_request(
            payload,
            request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
        )
        spec = self._build_scheduler_submit_spec(normalized)
        deadline_ts = None
        timeout_sec = normalized.timeout_sec
        if timeout_sec is not None:
            try:
                deadline_ts = request_start + float(timeout_sec)
            except Exception:
                deadline_ts = None
        self.api._register_request_state(
            request_id=spec.request_id,
            api_mode="scheduler_submit",
            backend="scheduler_v1",
            media_type=normalized.media_type,
            response_streaming=False,
            deadline_ts=deadline_ts,
            meta=self.api._build_request_meta(normalized.to_payload()),
        )
        self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"})
        spec_ready_at = time.perf_counter()
        prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
        self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms})
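        # Preparation runs on the engine GPU queue and reports when the
        # executor actually started and finished, so queueing time can be
        # separated from run time below.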
        try:
            state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
                spec=spec,
                prepare_submit_at=spec_ready_at,
                engine_request_id=spec.request_id,
            )
        except Exception as exc:
            self.api._fail_request_state(spec.request_id, str(exc))
            raise
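        # Decompose the prepare wall time into spec build, executor queue,
        # executor run, and unaccounted remainder.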
        prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0)
        prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)
        prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
        prepare_profile = dict(state.prepare_profile)
        prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
        prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
        prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms)
        self.api._update_request_state(
            spec.request_id,
            EngineStatus.READY_FOR_PREFILL,
            {
                "prepare_wall_ms": prepare_wall_ms,
                "prepare_profile_total_ms": prepare_profile_total_ms,
                "prepare_profile": prepare_profile,
            },
        )
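        # Hand the prepared state to the dispatcher; completion is signalled
        # through done_future on the current event loop.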
        api_after_prepare_start = time.perf_counter()
        loop = asyncio.get_running_loop()
        done_future = loop.create_future()
        await self.api._enqueue_prepared_state_for_dispatch(
            state=state,
            speed_factor=float(normalized.speed_factor),
            sample_steps=int(normalized.sample_steps),
            media_type=normalized.media_type,
            super_sampling=bool(normalized.super_sampling),
            prepare_wall_ms=prepare_wall_ms,
            prepare_profile_total_ms=prepare_profile_total_ms,
            done_loop=loop,
            done_future=done_future,
            engine_request_id=spec.request_id,
            timeout_sec=normalized.timeout_sec,
        )
        api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0)
        try:
            job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0))
        except Exception as exc:
            self.api._fail_request_state(spec.request_id, str(exc))
            raise
        wait_return_at = time.perf_counter()
        if job.error is not None:
            self.api._fail_request_state(spec.request_id, str(job.error))
            raise RuntimeError(str(job.error))
        if job.audio_data is None or job.sample_rate is None or job.result is None:
            self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result")
            raise RuntimeError(f"{job.request_id} finished without audio result")
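        # Pack the raw audio into the requested container format and measure
        # the remaining API-side overheads.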
        pack_start = time.perf_counter()
        audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
        pack_end = time.perf_counter()
        pack_ms = (pack_end - pack_start) * 1000.0
        api_wait_result_ms = 0.0
        if job.result_ready_time is not None:
            api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
        response_ready_at = time.perf_counter()
        response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
        submit_profile = self.api._build_scheduler_submit_profile(
            backend="scheduler_v1",
            request_start=request_start,
            response_ready_at=response_ready_at,
            audio_bytes=len(audio_data),
            sample_rate=int(job.sample_rate),
            prepare_spec_build_ms=prepare_spec_build_ms,
            prepare_wall_ms=prepare_wall_ms,
            prepare_executor_queue_ms=prepare_executor_queue_ms,
            prepare_executor_run_ms=prepare_executor_run_ms,
            prepare_profile_total_ms=prepare_profile_total_ms,
            prepare_profile_wall_ms=prepare_profile_wall_ms,
            prepare_other_ms=prepare_other_ms,
            engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)),
            api_after_prepare_ms=api_after_prepare_ms,
            api_wait_result_ms=api_wait_result_ms,
            pack_ms=pack_ms,
            response_overhead_ms=response_overhead_ms,
            worker_profile=dict(job.result or {}),
        )
        headers = self.api._build_scheduler_submit_headers(
            request_id=job.request_id,
            media_type=job.media_type,
            sample_rate=int(job.sample_rate),
            profile=submit_profile,
        )
        self.api._merge_request_state_profile(
            spec.request_id,
            dict(submit_profile, response_headers_emitted=True),
        )
        return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers)
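
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): ``EngineApi`` below is
# a hypothetical host object exposing the private hooks this class calls
# (_normalize_engine_request, _register_request_state, and so on), and the
# "text" payload key is an assumption; only "request_id" and "media_type" are
# confirmed by the code above.
#
#     import asyncio
#
#     async def main() -> None:
#         flow = EngineApiSchedulerFlow(EngineApi())
#         execution = await flow.run_scheduler_submit(
#             {"text": "hello", "media_type": "wav", "request_id": "demo_001"}
#         )
#         with open("demo.wav", "wb") as fh:
#             fh.write(execution.audio_bytes)
#
#     asyncio.run(main())
# ---------------------------------------------------------------------------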