GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py
baicai-1145 6a822b28c3 Enhance TTS API with improved request handling and asynchronous processing
Refactor api_v2.py and api_v3.py to update sampling parameters and weight paths for better clarity and support for v3/v4 vocoders. Introduce new methods in PrepareCoordinator for handling empty text features and improve profiling capabilities. Additionally, update unified engine components to streamline audio processing and state management, enhancing overall performance and maintainability of the TTS system.
2026-03-12 01:27:19 +08:00

190 lines
8.1 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry
def normalize_lang(value: str | None) -> str | None:
    """Lower-case a language code, passing ``None`` and ``""`` through unchanged."""
    if value is None or value == "":
        return value
    # str() is defensive: keeps behavior sane if a non-str slips past the annotation.
    return str(value).lower()
def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict:
    """Return a copy of *req* with the default reference audio filled in and
    any present language fields lower-cased.

    The input dict is never mutated; ``ref_audio_path`` is only filled when the
    request left it empty AND the registry actually has a default.
    """

    def _lowered(value):
        # Same contract as normalize_lang: None/"" pass through untouched.
        return value if value in (None, "") else str(value).lower()

    result = dict(req)
    fallback = reference_registry.get_default()
    if result.get("ref_audio_path") in (None, "") and fallback.ref_audio_path not in (None, ""):
        result["ref_audio_path"] = fallback.ref_audio_path
    for key in ("text_lang", "prompt_lang"):
        if key in result:
            result[key] = _lowered(result.get(key))
    return result
def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]:
    """Validate a raw request dict against the engine's capabilities.

    Returns a human-readable error message for the first failed check, or
    ``None`` when every parameter is acceptable. *req* is not modified.
    """
    empty = (None, "")
    supported_langs = tts.configs.languages

    ref_audio_path = req.get("ref_audio_path", "")
    if ref_audio_path in empty:
        return "ref_audio_path is required"

    text = req.get("text", "")
    if text in empty:
        return "text is required"

    text_lang = req.get("text_lang", "")
    if text_lang in empty:
        return "text_lang is required"
    if text_lang.lower() not in supported_langs:
        return f"text_lang: {text_lang} is not supported in version {tts.configs.version}"

    prompt_lang = req.get("prompt_lang", "")
    if prompt_lang in empty:
        return "prompt_lang is required"
    if prompt_lang.lower() not in supported_langs:
        return f"prompt_lang: {prompt_lang} is not supported in version {tts.configs.version}"

    media_type = req.get("media_type", "wav")
    if media_type not in ("wav", "raw", "ogg", "aac"):
        return f"media_type: {media_type} is not supported"

    split_method = req.get("text_split_method", "cut5")
    if split_method not in cut_method_names:
        return f"text_split_method:{split_method} is not supported"

    return None
def base_request_defaults() -> Dict[str, Any]:
    """Return a fresh dict holding the default value for every request field.

    A new dict is built on each call so callers may mutate the result freely.
    """
    return dict(
        # identity and input text
        request_id=None,
        text=None,
        text_lang=None,
        # reference audio and prompt
        ref_audio_path=None,
        aux_ref_audio_paths=None,
        prompt_text="",
        prompt_lang=None,
        # sampling
        top_k=15,
        top_p=1.0,
        temperature=1.0,
        # splitting / batching
        text_split_method="cut5",
        batch_size=1,
        batch_threshold=0.75,
        speed_factor=1.0,
        split_bucket=False,
        fragment_interval=0.3,
        seed=-1,
        # output / streaming flags
        media_type="wav",
        streaming_mode=False,
        return_fragment=False,
        fixed_length_chunk=False,
        response_streaming=False,
        # inference control
        parallel_infer=False,
        repetition_penalty=1.35,
        sample_steps=32,
        super_sampling=False,
        overlap_length=2,
        min_chunk_length=16,
        early_stop_num=-1,
        ready_step=0,
        timeout_sec=None,
    )
def normalize_streaming_mode(req: dict) -> dict:
    """Expand a request's ``streaming_mode`` into concrete streaming flags.

    Accepted modes: 0/False (no streaming), 1/True (fragment return),
    2 (streaming), 3 (streaming with fixed-length chunks). The derived
    ``response_streaming`` flag is True when any streaming output is active,
    including when the caller asked for ``return_fragment`` explicitly.

    Raises:
        ValueError: for any other ``streaming_mode`` value.
    """
    result = dict(req)
    requested_fragment = result.get("return_fragment", False)
    # mode -> (streaming_mode, return_fragment, fixed_length_chunk).
    # bool keys hash/compare as 0/1, matching the original ==-based dispatch.
    flag_table = {
        0: (False, False, False),
        1: (False, True, False),
        2: (True, False, False),
        3: (True, False, True),
    }
    mode = result.get("streaming_mode", False)
    try:
        streaming, fragment, fixed_chunk = flag_table[mode]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. lists); both cases are
        # reported as the same invalid-value error.
        raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)")
    result["streaming_mode"] = streaming
    result["return_fragment"] = fragment
    result["fixed_length_chunk"] = fixed_chunk
    result["response_streaming"] = bool(streaming or fragment or requested_fragment)
    return result
def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
    """Return True when auxiliary reference audio paths were actually supplied.

    Uses equality (not truthiness) against the empty sentinels, mirroring how
    the rest of this module tests for "unset" values.
    """
    empty_values = (None, [], ())
    return aux_ref_audio_paths not in empty_values
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
    """Pick the backend for a normalized request.

    Currently every request is routed to the v1 direct scheduler; the request
    itself is not inspected. The second tuple slot (no alternate backend) is
    always None.
    """
    backend_name = "scheduler_v1_direct"
    return (backend_name, None)
def normalize_engine_request(
    *,
    tts: TTS,
    cut_method_names: Sequence[str],
    reference_registry: ReferenceRegistry,
    payload: dict | NormalizedEngineRequest,
    request_id: str | None = None,
    normalize_streaming: bool = False,
    error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
    """Validate and coerce a raw payload into a ``NormalizedEngineRequest``.

    Pipeline: merge the payload over ``base_request_defaults()`` (or round-trip
    an existing ``NormalizedEngineRequest`` through ``to_payload()``), resolve
    ``request_id``, fill the default reference audio and lower-case language
    fields, optionally expand the streaming flags, validate via
    ``check_params``, then coerce every field to its final type.

    Args:
        tts: engine whose ``configs`` drive language validation.
        cut_method_names: allowed values for ``text_split_method``.
        reference_registry: source of the default reference audio path.
        payload: raw request dict, or an already-normalized request.
        request_id: when non-empty, overrides any id carried by the payload.
        normalize_streaming: when True, expand ``streaming_mode`` through
            ``normalize_streaming_mode`` before validation.
        error_prefix: prefix prepended to validation error messages.

    Returns:
        A fully-typed ``NormalizedEngineRequest``.

    Raises:
        ValueError: if no request id is available after normalization, or if
            ``check_params`` rejects the payload (message gets *error_prefix*).
    """
    if isinstance(payload, NormalizedEngineRequest):
        # Re-normalizing an existing request: start from its own payload dict.
        normalized_payload = payload.to_payload()
    else:
        # Plain dict: caller-supplied keys win over the module defaults.
        normalized_payload = base_request_defaults()
        normalized_payload.update(dict(payload))
    # An explicit request_id argument beats whatever the payload carries;
    # a request must end up with some non-empty id either way.
    if request_id not in [None, ""]:
        normalized_payload["request_id"] = str(request_id)
    elif normalized_payload.get("request_id") in [None, ""]:
        raise ValueError("request_id is required after normalization")
    normalized_payload = apply_default_reference(reference_registry, normalized_payload)
    if normalize_streaming:
        normalized_payload = normalize_streaming_mode(normalized_payload)
    # Validate only after the default reference and streaming expansion so
    # the checks see the payload the engine will actually receive.
    error = check_params(tts, cut_method_names, normalized_payload)
    if error is not None:
        raise ValueError(f"{error_prefix}{error}")
    timeout_sec = normalized_payload.get("timeout_sec")
    parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec)
    aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths")
    # Empty aux-reference lists collapse to None; otherwise force every entry to str.
    normalized_aux_ref_audio_paths = None if aux_ref_audio_paths in [None, "", []] else [str(item) for item in aux_ref_audio_paths]
    return NormalizedEngineRequest(
        request_id=str(normalized_payload["request_id"]),
        text=str(normalized_payload["text"]),
        text_lang=str(normalized_payload["text_lang"]),
        ref_audio_path=str(normalized_payload["ref_audio_path"]),
        prompt_lang=str(normalized_payload["prompt_lang"]),
        # prompt_text is optional: None degrades to the empty string.
        prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")),
        aux_ref_audio_paths=normalized_aux_ref_audio_paths,
        top_k=int(normalized_payload["top_k"]),
        top_p=float(normalized_payload["top_p"]),
        temperature=float(normalized_payload["temperature"]),
        repetition_penalty=float(normalized_payload["repetition_penalty"]),
        early_stop_num=int(normalized_payload.get("early_stop_num", -1)),
        ready_step=int(normalized_payload.get("ready_step", 0)),
        text_split_method=str(normalized_payload["text_split_method"]),
        batch_size=int(normalized_payload["batch_size"]),
        batch_threshold=float(normalized_payload["batch_threshold"]),
        split_bucket=bool(normalized_payload["split_bucket"]),
        speed_factor=float(normalized_payload["speed_factor"]),
        fragment_interval=float(normalized_payload["fragment_interval"]),
        seed=int(normalized_payload["seed"]),
        media_type=str(normalized_payload["media_type"]),
        # NOTE(review): streaming_mode is passed through uncoerced — when
        # normalize_streaming is False a caller-supplied int (e.g. 2) survives
        # here as-is; confirm downstream consumers accept non-bool values.
        streaming_mode=normalized_payload["streaming_mode"],
        return_fragment=bool(normalized_payload.get("return_fragment", False)),
        fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)),
        response_streaming=bool(normalized_payload.get("response_streaming", False)),
        parallel_infer=bool(normalized_payload["parallel_infer"]),
        sample_steps=int(normalized_payload["sample_steps"]),
        super_sampling=bool(normalized_payload["super_sampling"]),
        overlap_length=int(normalized_payload["overlap_length"]),
        min_chunk_length=int(normalized_payload["min_chunk_length"]),
        timeout_sec=parsed_timeout,
    )