GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py
baicai-1145 3fd4f48651 Add unified engine API modules for direct and scheduler-based TTS processing
Introduce new modules including unified_engine_api_direct, unified_engine_api_profile, unified_engine_api_request, and unified_engine_api_scheduler. These additions enhance the TTS system by providing structured interfaces for direct TTS execution and scheduler-based processing. The new components support improved request handling, profiling, and state management, significantly enhancing the architecture and maintainability of the TTS framework.
2026-03-11 18:36:24 +08:00

200 lines
8.7 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry
def normalize_lang(value: str | None) -> str | None:
    """Lower-case a language code, leaving empty values untouched.

    ``None`` and the empty string are returned unchanged so callers can still
    distinguish "not supplied" from a real language code. Any other value is
    converted with ``str`` and lower-cased.
    """
    is_blank = value in [None, ""]
    return value if is_blank else str(value).lower()
def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict:
    """Return a copy of ``req`` with the default reference and language codes applied.

    The input mapping is never mutated. When the request carries no
    ``ref_audio_path`` (``None`` or empty string) and the registry's default
    entry has a non-empty ``ref_audio_path``, that default path is filled in.
    ``text_lang`` and ``prompt_lang`` are passed through ``normalize_lang``
    when the keys are present.
    """
    result = dict(req)
    # The registry is always queried, even when the request already names a
    # reference audio file.
    fallback = reference_registry.get_default()
    if result.get("ref_audio_path") in [None, ""]:
        if fallback.ref_audio_path not in [None, ""]:
            result["ref_audio_path"] = fallback.ref_audio_path
    for lang_key in ("text_lang", "prompt_lang"):
        if lang_key in result:
            result[lang_key] = normalize_lang(result.get(lang_key))
    return result
def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]:
    """Validate the required and enumerated fields of a TTS request.

    Checks run in a fixed order and the first failure wins:
    ``ref_audio_path``, ``text``, ``text_lang``, ``prompt_lang``,
    ``media_type`` and finally ``text_split_method``.

    Args:
        tts: Object exposing ``configs.languages`` (supported language codes)
            and ``configs.version`` (used only in error messages).
        cut_method_names: Accepted values for ``text_split_method``.
        req: Request mapping; missing keys fall back to the same defaults the
            lookups below use (``"wav"`` media type, ``"cut5"`` split method).

    Returns:
        A human-readable error message, or ``None`` when the request is valid.
    """
    if req.get("ref_audio_path", "") in [None, ""]:
        return "ref_audio_path is required"
    if req.get("text", "") in [None, ""]:
        return "text is required"

    for field in ("text_lang", "prompt_lang"):
        lang = req.get(field, "")
        if lang in [None, ""]:
            return f"{field} is required"
        # Coerce to str before lower-casing so a non-string value (for example
        # a JSON number) yields a validation message instead of AttributeError.
        if str(lang).lower() not in tts.configs.languages:
            return f"{field}: {lang} is not supported in version {tts.configs.version}"

    media_type = req.get("media_type", "wav")
    if media_type not in ["wav", "raw", "ogg", "aac"]:
        return f"media_type: {media_type} is not supported"

    text_split_method = req.get("text_split_method", "cut5")
    if text_split_method not in cut_method_names:
        return f"text_split_method:{text_split_method} is not supported"
    return None
def base_request_defaults() -> Dict[str, Any]:
    """Build a fresh mapping holding the default value of every request field.

    A new dict is created on each call, so callers may mutate the result
    freely. Fields without a meaningful default are set to ``None``.
    """
    defaults: Dict[str, Any] = dict(
        # Identity, text and reference audio.
        request_id=None,
        text=None,
        text_lang=None,
        ref_audio_path=None,
        aux_ref_audio_paths=None,
        prompt_text="",
        prompt_lang=None,
        # Sampling.
        top_k=15,
        top_p=1.0,
        temperature=1.0,
        # Text splitting and batching.
        text_split_method="cut5",
        batch_size=1,
        batch_threshold=0.75,
        speed_factor=1.0,
        split_bucket=False,
        fragment_interval=0.3,
        seed=-1,
        # Output format and streaming flags.
        media_type="wav",
        streaming_mode=False,
        return_fragment=False,
        fixed_length_chunk=False,
        response_streaming=False,
        parallel_infer=False,
        # Remaining generation and scheduling knobs.
        repetition_penalty=1.35,
        sample_steps=32,
        super_sampling=False,
        overlap_length=2,
        min_chunk_length=16,
        early_stop_num=-1,
        ready_step=0,
        timeout_sec=None,
    )
    return defaults
def normalize_streaming_mode(req: dict) -> dict:
    """Expand the ``streaming_mode`` selector into explicit boolean flags.

    Works on a copy of ``req``. The selector (default ``False``) is mapped to
    the ``streaming_mode`` / ``return_fragment`` / ``fixed_length_chunk``
    triple as follows:

    * ``0`` or ``False`` -> ``(False, False, False)``
    * ``1`` or ``True``  -> ``(False, True, False)``
    * ``2``              -> ``(True, False, False)``
    * ``3``              -> ``(True, False, True)``

    ``response_streaming`` is then set to true when either resulting flag is
    set or when the caller's own ``return_fragment`` value was truthy.

    Raises:
        ValueError: if the selector matches none of the values above.
    """
    result = dict(req)
    requested_mode = result.get("streaming_mode", False)
    requested_fragment = result.get("return_fragment", False)

    # Booleans match through equality because False == 0 and True == 1.
    mode_table = (
        (0, (False, False, False)),
        (1, (False, True, False)),
        (2, (True, False, False)),
        (3, (True, False, True)),
    )
    for mode, flags in mode_table:
        if requested_mode == mode:
            break
    else:
        raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)")

    result["streaming_mode"], result["return_fragment"], result["fixed_length_chunk"] = flags
    # NOTE(review): the caller-supplied return_fragment still contributes here
    # even when the selected mode reset return_fragment to False above --
    # confirm this is intended.
    result["response_streaming"] = bool(
        result["streaming_mode"] or result["return_fragment"] or requested_fragment
    )
    return result
def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
    """Report whether any auxiliary reference audio paths were supplied.

    ``None``, an empty list and an empty tuple all count as "not supplied";
    every other value counts as enabled.
    """
    empty_markers = [None, [], ()]
    if aux_ref_audio_paths in empty_markers:
        return False
    return True
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
    """Choose the backend name for a direct request, plus the reason for a fallback.

    Returns a ``(backend, reason)`` pair. The conditions are evaluated in
    order and the first match wins:

    1. streaming response with fragment or fixed-length chunking
    2. any other streaming response
    3. auxiliary reference audio present
    4. super sampling requested
    5. prompt text missing (``None`` or empty)

    When none applies the result is ``("scheduler_v1_direct", None)``.
    """
    if normalized.response_streaming:
        fragment_like = normalized.return_fragment or normalized.fixed_length_chunk
        return (
            ("legacy_direct_fragment", "fragment_streaming_mode")
            if fragment_like
            else ("legacy_direct_streaming", "streaming_mode")
        )

    if is_aux_ref_enabled(normalized.aux_ref_audio_paths):
        return "legacy_direct_aux_ref", "aux_ref_audio_paths"

    if normalized.super_sampling:
        return "legacy_direct_super_sampling", "super_sampling"

    prompt_missing = normalized.prompt_text in [None, ""]
    if prompt_missing:
        return "legacy_direct_missing_prompt", "missing_prompt_text"

    return "scheduler_v1_direct", None
def normalize_engine_request(
    *,
    tts: TTS,
    cut_method_names: Sequence[str],
    reference_registry: ReferenceRegistry,
    payload: dict | NormalizedEngineRequest,
    request_id: str | None = None,
    normalize_streaming: bool = False,
    error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
    """Turn a raw request payload into a validated ``NormalizedEngineRequest``.

    Steps, in order:

    1. Start from ``payload.to_payload()`` when ``payload`` is already a
       ``NormalizedEngineRequest``; otherwise overlay the mapping on top of
       ``base_request_defaults()``.
    2. Resolve the request id: an explicit ``request_id`` argument wins,
       otherwise the payload must already carry a non-empty one.
    3. Apply the registry's default reference audio and lower-case the
       language codes.
    4. Optionally expand ``streaming_mode`` into explicit flags.
    5. Validate with ``check_params`` and coerce every field to its final type.

    Args:
        tts: Engine object handed to ``check_params`` for language validation.
        cut_method_names: Accepted ``text_split_method`` values.
        reference_registry: Source of the default reference audio.
        payload: Raw request mapping or an already-normalized request.
        request_id: Overrides any id present in ``payload`` when non-empty.
        normalize_streaming: When true, run ``normalize_streaming_mode``.
        error_prefix: Prepended to validation messages from ``check_params``.

    Returns:
        The fully-populated ``NormalizedEngineRequest``.

    Raises:
        ValueError: if no request id is available, if ``check_params`` rejects
            the request, or if a numeric field holds an unparsable string.
    """
    if isinstance(payload, NormalizedEngineRequest):
        normalized_payload = payload.to_payload()
    else:
        normalized_payload = base_request_defaults()
        normalized_payload.update(dict(payload))

    if request_id not in [None, ""]:
        normalized_payload["request_id"] = str(request_id)
    elif normalized_payload.get("request_id") in [None, ""]:
        raise ValueError("request_id is required after normalization")

    normalized_payload = apply_default_reference(reference_registry, normalized_payload)
    if normalize_streaming:
        normalized_payload = normalize_streaming_mode(normalized_payload)

    error = check_params(tts, cut_method_names, normalized_payload)
    if error is not None:
        raise ValueError(f"{error_prefix}{error}")

    timeout_sec = normalized_payload.get("timeout_sec")
    parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec)

    aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths")
    if aux_ref_audio_paths in [None, "", []]:
        normalized_aux_ref_audio_paths = None
    elif isinstance(aux_ref_audio_paths, str):
        # A bare string iterates character by character; wrap it so a single
        # path is kept whole instead of becoming one bogus entry per character.
        normalized_aux_ref_audio_paths = [aux_ref_audio_paths]
    else:
        normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths]

    prompt_text = normalized_payload.get("prompt_text")
    return NormalizedEngineRequest(
        request_id=str(normalized_payload["request_id"]),
        text=str(normalized_payload["text"]),
        text_lang=str(normalized_payload["text_lang"]),
        ref_audio_path=str(normalized_payload["ref_audio_path"]),
        prompt_lang=str(normalized_payload["prompt_lang"]),
        prompt_text="" if prompt_text is None else str(prompt_text),
        aux_ref_audio_paths=normalized_aux_ref_audio_paths,
        top_k=int(normalized_payload["top_k"]),
        top_p=float(normalized_payload["top_p"]),
        temperature=float(normalized_payload["temperature"]),
        repetition_penalty=float(normalized_payload["repetition_penalty"]),
        early_stop_num=int(normalized_payload.get("early_stop_num", -1)),
        ready_step=int(normalized_payload.get("ready_step", 0)),
        text_split_method=str(normalized_payload["text_split_method"]),
        batch_size=int(normalized_payload["batch_size"]),
        batch_threshold=float(normalized_payload["batch_threshold"]),
        split_bucket=bool(normalized_payload["split_bucket"]),
        speed_factor=float(normalized_payload["speed_factor"]),
        fragment_interval=float(normalized_payload["fragment_interval"]),
        seed=int(normalized_payload["seed"]),
        media_type=str(normalized_payload["media_type"]),
        # Passed through uncoerced: when normalize_streaming is False this may
        # still be the raw selector value rather than a bool.
        streaming_mode=normalized_payload["streaming_mode"],
        return_fragment=bool(normalized_payload.get("return_fragment", False)),
        fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)),
        response_streaming=bool(normalized_payload.get("response_streaming", False)),
        parallel_infer=bool(normalized_payload["parallel_infer"]),
        sample_steps=int(normalized_payload["sample_steps"]),
        super_sampling=bool(normalized_payload["super_sampling"]),
        overlap_length=int(normalized_payload["overlap_length"]),
        min_chunk_length=int(normalized_payload["min_chunk_length"]),
        timeout_sec=parsed_timeout,
    )