Enhance TTS API with improved request handling and asynchronous processing

Refactor api_v2.py and api_v3.py to update sampling parameters and weight paths for better clarity and support for v3/v4 vocoders. Introduce new methods in PrepareCoordinator for handling empty text features and improve profiling capabilities. Additionally, update unified engine components to streamline audio processing and state management, enhancing overall performance and maintainability of the TTS system.
This commit is contained in:
baicai-1145 2026-03-12 01:27:19 +08:00
parent d453a8e47c
commit 6a822b28c3
19 changed files with 494 additions and 149 deletions

View File

@ -323,7 +323,7 @@ class TTS_Config:
assert isinstance(configs, dict) assert isinstance(configs, dict)
configs_ = deepcopy(self.default_configs) configs_ = deepcopy(self.default_configs)
configs_.update(configs) configs_.update(configs)
self.configs: dict = configs_.get("custom", configs_["v2"]) self.configs: dict = configs_.get("custom", configs_["v2ProPlus"])
self.default_configs = deepcopy(configs_) self.default_configs = deepcopy(configs_)
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
@ -1872,19 +1872,19 @@ class TTS:
self.init_sr_model() self.init_sr_model()
if not self.sr_model_not_exist: if not self.sr_model_not_exist:
audio, sr = self.sr_model(audio.unsqueeze(0), sr) audio, sr = self.sr_model(audio.unsqueeze(0), sr)
max_audio = np.abs(audio).max() if isinstance(audio, torch.Tensor):
max_audio = float(torch.abs(audio).max().item())
else:
max_audio = float(np.abs(audio).max())
if max_audio > 1: if max_audio > 1:
audio /= max_audio audio /= max_audio
audio = (audio * 32768).astype(np.int16)
t2 = time.perf_counter() t2 = time.perf_counter()
print(f"超采样用时:{t2 - t1:.3f}s") print(f"超采样用时:{t2 - t1:.3f}s")
if isinstance(audio, torch.Tensor):
audio = audio.detach().float().cpu().numpy()
else: else:
# audio = audio.float() * 32768 audio = np.asarray(audio)
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy() audio = (audio.reshape(-1) * 32768).astype(np.int16)
audio = audio.cpu().numpy()
audio = (audio * 32768).astype(np.int16)
# try: # try:
@ -2036,20 +2036,23 @@ class TTS:
phones: torch.Tensor, phones: torch.Tensor,
prompt_semantic: torch.Tensor, prompt_semantic: torch.Tensor,
prompt_phones: torch.Tensor, prompt_phones: torch.Tensor,
refer_spec: tuple, refer_spec: tuple | List[tuple],
raw_audio: torch.Tensor, raw_audio: torch.Tensor,
raw_sr: int, raw_sr: int,
speed: float = 1.0, speed: float = 1.0,
sample_steps: int = 32, sample_steps: int = 32,
): ):
refer_audio_spec, audio_tensor = refer_spec refer_specs = list(refer_spec) if isinstance(refer_spec, list) else [refer_spec]
refer_audio_spec, audio_tensor = refer_specs[0]
if not self.configs.use_vocoder: if not self.configs.use_vocoder:
refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)] refer_audio_spec_list = [item[0].to(dtype=self.precision, device=self.configs.device) for item in refer_specs]
sv_emb = None sv_emb = None
if self.is_v2pro: if self.is_v2pro:
if audio_tensor is None: sv_emb = []
for _, audio_tensor_item in refer_specs:
if audio_tensor_item is None:
raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频"))
sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device) sv_emb.append(self.sv_model.compute_embedding3(audio_tensor_item).to(self.configs.device))
return self.vits_model.decode( return self.vits_model.decode(
semantic_tokens, semantic_tokens,
phones, phones,
@ -2075,7 +2078,7 @@ class TTS:
self, self,
semantic_tokens_list: List[torch.Tensor], semantic_tokens_list: List[torch.Tensor],
phones_list: List[torch.Tensor], phones_list: List[torch.Tensor],
refer_specs: List[tuple], refer_specs: List[tuple | List[tuple]],
speeds: List[float] | None = None, speeds: List[float] | None = None,
sample_steps_list: List[int] | None = None, sample_steps_list: List[int] | None = None,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
@ -2118,7 +2121,11 @@ class TTS:
semantic_lengths.append(semantic_len) semantic_lengths.append(semantic_len)
phone_lengths.append(phone_len) phone_lengths.append(phone_len)
refer_audio_spec, audio_tensor = refer_specs[batch_index] refer_spec_item = refer_specs[batch_index]
refer_spec_group = list(refer_spec_item) if isinstance(refer_spec_item, list) else [refer_spec_item]
if len(refer_spec_group) != 1:
raise ValueError("batched request-local synthesis 暂不支持单请求多参考音频")
refer_audio_spec, audio_tensor = refer_spec_group[0]
refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device)) refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device))
if self.is_v2pro: if self.is_v2pro:
if audio_tensor is None: if audio_tensor is None:

View File

@ -12,6 +12,7 @@ from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
PreparedTextFeatures, PreparedTextFeatures,
SchedulerRequestSpec, SchedulerRequestSpec,
T2SRequestState, T2SRequestState,
build_empty_text_features,
build_request_state_from_parts, build_request_state_from_parts,
normalize_sentence, normalize_sentence,
) )
@ -118,6 +119,21 @@ class PrepareCoordinator:
def _prepare_text_cpu(self, text: str, language: str): def _prepare_text_cpu(self, text: str, language: str):
return self.tts.prepare_text_segments(text, language) return self.tts.prepare_text_segments(text, language)
@staticmethod
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
feature_dim = 1024
dtype = None
if reference is not None:
try:
feature_dim = int(reference.bert_features.shape[0])
dtype = reference.bert_features.dtype
except Exception:
pass
return build_empty_text_features(
feature_dim=int(feature_dim),
dtype=(dtype if dtype is not None else None) or __import__("torch").float32,
)
def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures: def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)} profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
branch_start = time.perf_counter() branch_start = time.perf_counter()
@ -139,6 +155,9 @@ class PrepareCoordinator:
return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args)
async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult:
if text in [None, ""]:
submit_at = time.perf_counter()
return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at)
executor = getattr(self.tts, "prepare_text_cpu_executor", None) executor = getattr(self.tts, "prepare_text_cpu_executor", None)
if executor is None: if executor is None:
submit_at = time.perf_counter() submit_at = time.perf_counter()
@ -164,19 +183,71 @@ class PrepareCoordinator:
prompt_cpu_run_ms: float, prompt_cpu_run_ms: float,
target_cpu_run_ms: float, target_cpu_run_ms: float,
) -> tuple[ProfiledResult, ProfiledResult]: ) -> tuple[ProfiledResult, ProfiledResult]:
prompt_is_empty = len(prompt_segments or []) == 0
if self.text_feature_executor is not None: if self.text_feature_executor is not None:
prompt_feature_task = asyncio.create_task( target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms))
self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms) if not prompt_is_empty:
) prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms))
target_feature_task = asyncio.create_task(
self._run_text_feature_stage(target_segments, None, target_cpu_run_ms)
)
return await asyncio.gather(prompt_feature_task, target_feature_task) return await asyncio.gather(prompt_feature_task, target_feature_task)
target_profiled = await target_feature_task
submit_at = time.perf_counter()
prompt_profiled = ProfiledResult(
result=self._build_empty_text_features_like(target_profiled.result),
submit_at=float(submit_at),
started_at=float(submit_at),
finished_at=float(submit_at),
)
return prompt_profiled, target_profiled
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
submit_at = time.perf_counter() submit_at = time.perf_counter()
started_at = float(submit_at) started_at = float(submit_at)
if prompt_is_empty:
target_result_raw = await self.tts.build_text_features_from_segments_async(
target_segments,
profile=target_profile,
)
prompt_result = self._build_empty_text_features_like(
PreparedTextFeatures(
phones=target_result_raw[0],
bert_features=target_result_raw[1],
norm_text=target_result_raw[2],
profile=target_profile,
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
cpu_preprocess_ms=float(target_cpu_run_ms),
)
)
finished_at = time.perf_counter()
prompt_profiled = ProfiledResult(
result=prompt_result,
submit_at=float(submit_at),
started_at=float(submit_at),
finished_at=float(submit_at),
)
target_result = PreparedTextFeatures(
phones=target_result_raw[0],
bert_features=target_result_raw[1],
norm_text=target_result_raw[2],
profile=target_profile,
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
cpu_preprocess_ms=float(target_cpu_run_ms),
)
target_profiled = ProfiledResult(
result=target_result,
submit_at=float(submit_at),
started_at=started_at,
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
)
if finished_at > target_profiled.finished_at:
target_result.profile["bert_total_ms"] = max(
self._estimate_text_feature_run_ms(target_profile),
(finished_at - submit_at) * 1000.0,
)
else:
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
return prompt_profiled, target_profiled
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
prompt_segments, prompt_segments,
target_segments, target_segments,

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
import os
from pathlib import Path from pathlib import Path
import time import time
from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import Any, Dict, List, Optional, Sequence, Tuple
@ -35,6 +36,7 @@ class SchedulerRequestSpec:
temperature: float temperature: float
repetition_penalty: float repetition_penalty: float
early_stop_num: int early_stop_num: int
aux_ref_audio_paths: List[str] = field(default_factory=list)
ready_step: int = 0 ready_step: int = 0
@ -54,6 +56,7 @@ class T2SRequestState:
all_bert_features: torch.Tensor all_bert_features: torch.Tensor
prompt_semantic: torch.LongTensor prompt_semantic: torch.LongTensor
refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]]
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
raw_audio: torch.Tensor raw_audio: torch.Tensor
raw_sr: int raw_sr: int
top_k: int top_k: int
@ -113,6 +116,21 @@ class PreparedTextFeatures:
cpu_preprocess_ms: float cpu_preprocess_ms: float
def build_empty_text_features(
*,
feature_dim: int = 1024,
dtype: torch.dtype = torch.float32,
) -> PreparedTextFeatures:
return PreparedTextFeatures(
phones=[],
bert_features=torch.empty((int(feature_dim), 0), dtype=dtype),
norm_text="",
profile={"cpu_preprocess_ms": 0.0, "bert_total_ms": 0.0},
total_ms=0.0,
cpu_preprocess_ms=0.0,
)
def normalize_sentence(text: str, language: str) -> str: def normalize_sentence(text: str, language: str) -> str:
text = text.strip("\n").strip() text = text.strip("\n").strip()
if not text: if not text:
@ -171,6 +189,14 @@ def build_request_state_from_parts(
bundle_profile = ref_audio_bundle.get("profile", {}) bundle_profile = ref_audio_bundle.get("profile", {})
prompt_semantic = ref_audio_bundle["prompt_semantic"].long() prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
spec_audio, audio_16k = ref_audio_bundle["refer_spec"] spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []):
if aux_ref_audio_path in [None, ""]:
continue
if not os.path.exists(str(aux_ref_audio_path)):
continue
aux_spec_audio, aux_audio_16k, _, _ = tts.extract_ref_spec(str(aux_ref_audio_path))
aux_refer_specs.append((aux_spec_audio, aux_audio_16k))
raw_audio = ref_audio_bundle["raw_audio"] raw_audio = ref_audio_bundle["raw_audio"]
raw_sr = int(ref_audio_bundle["raw_sr"]) raw_sr = int(ref_audio_bundle["raw_sr"])
prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
@ -182,9 +208,9 @@ def build_request_state_from_parts(
phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device) phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device)
prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device) prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device)
all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device) all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device)
all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to( prompt_bert_features = prompt_result.bert_features.to(dtype=tts.precision, device=tts.configs.device)
dtype=tts.precision, device=tts.configs.device target_bert_features = target_result.bert_features.to(dtype=tts.precision, device=tts.configs.device)
) all_bert_features = torch.cat([prompt_bert_features, target_bert_features], dim=1)
_sync_device(device) _sync_device(device)
tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0
@ -280,6 +306,7 @@ def build_request_state_from_parts(
all_bert_features=all_bert_features, all_bert_features=all_bert_features,
prompt_semantic=prompt_semantic, prompt_semantic=prompt_semantic,
refer_spec=(spec_audio, audio_16k), refer_spec=(spec_audio, audio_16k),
aux_refer_specs=aux_refer_specs,
raw_audio=raw_audio, raw_audio=raw_audio,
raw_sr=raw_sr, raw_sr=raw_sr,
top_k=spec.top_k, top_k=spec.top_k,
@ -301,10 +328,16 @@ def prepare_request_state(
prepare_sync_start = time.perf_counter() prepare_sync_start = time.perf_counter()
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
text = spec.text.strip("\n") text = spec.text.strip("\n")
prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang)
target_result = prepare_text_features(tts, text, spec.text_lang) target_result = prepare_text_features(tts, text, spec.text_lang)
if target_result.phones is None: if target_result.phones is None:
raise ValueError(f"{spec.request_id} text preprocessing returned no phones") raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
if prompt_text in [None, ""]:
prompt_result = build_empty_text_features(
feature_dim=int(target_result.bert_features.shape[0]),
dtype=target_result.bert_features.dtype,
)
else:
prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang)
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
return build_request_state_from_parts( return build_request_state_from_parts(
tts=tts, tts=tts,

View File

@ -119,6 +119,7 @@ class EngineApiFacade:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None, done_loop: asyncio.AbstractEventLoop | None,
@ -131,6 +132,7 @@ class EngineApiFacade:
speed_factor=speed_factor, speed_factor=speed_factor,
sample_steps=sample_steps, sample_steps=sample_steps,
media_type=media_type, media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms, prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop, done_loop=done_loop,

View File

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import queue
import threading
import time import time
import uuid import uuid
from io import BytesIO from io import BytesIO
@ -122,63 +124,28 @@ class EngineApiDirectFlow:
payload["response_streaming"] = False payload["response_streaming"] = False
return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ")
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: async def _execute_single_segment_scheduler_job(
request_start = time.perf_counter() self,
request_id = normalized.request_id normalized: NormalizedEngineRequest,
media_type = normalized.media_type *,
segment_texts = self._segment_direct_text(normalized) segment_request: NormalizedEngineRequest,
if not segment_texts: ) -> tuple[SchedulerPendingJob, Dict[str, Any]]:
raise ValueError("text preprocessing returned no valid segments") spec = self.api._build_scheduler_submit_spec(segment_request)
self.api._update_request_state( state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
request_id,
EngineStatus.CPU_PREPARING,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
)
segment_specs = []
for segment_index, segment_text in enumerate(segment_texts):
segment_request = self._build_segment_request(
normalized,
request_id=f"{request_id}_seg_{segment_index:03d}",
text=segment_text,
)
segment_specs.append(self.api._build_scheduler_submit_spec(segment_request))
prepared_items = await asyncio.gather(
*[
self.api._prepare_state_via_engine_gpu_queue(
spec=spec, spec=spec,
prepare_submit_at=time.perf_counter(), prepare_submit_at=time.perf_counter(),
engine_request_id=None, engine_request_id=None,
) )
for spec in segment_specs
]
)
prepare_profiles: List[Dict[str, Any]] = []
loop = asyncio.get_running_loop()
done_futures: List[asyncio.Future] = []
self.api._update_request_state(
request_id,
EngineStatus.READY_FOR_PREFILL,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)},
)
for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items):
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
prepare_profiles.append( loop = asyncio.get_running_loop()
{
"request_id": spec.request_id,
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"prepare_profile": dict(state.prepare_profile),
}
)
done_future = loop.create_future() done_future = loop.create_future()
done_futures.append(done_future)
await self.api._enqueue_prepared_state_for_dispatch( await self.api._enqueue_prepared_state_for_dispatch(
state=state, state=state,
speed_factor=float(normalized.speed_factor), speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps), sample_steps=int(normalized.sample_steps),
media_type=media_type, media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=prepare_wall_ms, prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=loop, done_loop=loop,
@ -186,6 +153,193 @@ class EngineApiDirectFlow:
engine_request_id=None, engine_request_id=None,
timeout_sec=normalized.timeout_sec, timeout_sec=normalized.timeout_sec,
) )
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec)
return job, {
"request_id": spec.request_id,
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"prepare_profile": dict(state.prepare_profile),
}
def _iter_scheduler_direct_tts_bytes(self, normalized: NormalizedEngineRequest) -> Generator[bytes, None, None]:
request_start = time.perf_counter()
request_id = normalized.request_id
media_type = normalized.media_type
segment_texts = self._segment_direct_text(normalized)
if not segment_texts:
raise ValueError("text preprocessing returned no valid segments")
chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8)
done_marker = object()
async def _produce_chunks() -> None:
self.api._update_request_state(
request_id,
EngineStatus.CPU_PREPARING,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
)
sample_rate: int | None = None
current_media_type = media_type
chunk_count = 0
stream_total_bytes = 0
first_chunk_ms: float | None = None
prepare_profiles: List[Dict[str, Any]] = []
worker_profiles: List[Dict[str, Any]] = []
try:
for segment_index, segment_text in enumerate(segment_texts):
segment_request = self._build_segment_request(
normalized,
request_id=f"{request_id}_seg_{segment_index:03d}",
text=segment_text,
)
self.api._update_request_state(
request_id,
EngineStatus.READY_FOR_PREFILL,
{
"backend": "scheduler_v1_direct",
"backend_mode": "scheduler_v1_direct",
"segment_index": segment_index,
"segment_count": len(segment_texts),
},
)
job, prepare_profile = await self._execute_single_segment_scheduler_job(
normalized,
segment_request=segment_request,
)
prepare_profiles.append(prepare_profile)
if job.error is not None:
raise RuntimeError(job.error)
if job.audio_data is None or job.sample_rate is None or job.result is None:
raise RuntimeError(f"{job.request_id} finished without audio result")
worker_profiles.append(dict(job.result))
if sample_rate is None:
sample_rate = int(job.sample_rate)
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
self.api._update_request_state(
request_id,
EngineStatus.STREAMING,
{
"backend": "scheduler_v1_direct",
"backend_mode": "scheduler_v1_direct",
"sample_rate": int(sample_rate),
},
)
if media_type == "wav":
header = wave_header_chunk(sample_rate=int(sample_rate))
chunk_count += 1
stream_total_bytes += len(header)
chunk_queue.put(header)
current_media_type = "raw"
packed_chunk = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), current_media_type).getvalue()
chunk_count += 1
stream_total_bytes += len(packed_chunk)
chunk_queue.put(packed_chunk)
if segment_index + 1 < len(segment_texts):
silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate))
if silence_samples > 0:
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
packed_silence = pack_audio(
BytesIO(), silence_chunk, int(job.sample_rate), current_media_type
).getvalue()
chunk_count += 1
stream_total_bytes += len(packed_silence)
chunk_queue.put(packed_silence)
except Exception as exc:
self.api._fail_request_state(request_id, str(exc))
chunk_queue.put(exc)
else:
self.api._merge_request_state_profile(
request_id,
{
"prepare_aggregate": self.api._aggregate_numeric_dicts(
[item["prepare_profile"] for item in prepare_profiles]
),
"engine_policy_wait_ms": sum(
float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles
),
"engine_dispatch_wait_ms": sum(
float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles
),
},
)
direct_profile = self.api._build_direct_scheduler_profile(
backend="scheduler_v1_direct",
request_start=request_start,
response_ready_at=time.perf_counter(),
audio_bytes=stream_total_bytes,
sample_rate=int(sample_rate or 0),
segment_texts=segment_texts,
prepare_profiles=prepare_profiles,
worker_profiles=worker_profiles,
pack_ms=0.0,
response_overhead_ms=0.0,
)
self.api._complete_request_state(
request_id,
dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms),
)
finally:
chunk_queue.put(done_marker)
producer_thread = threading.Thread(target=lambda: asyncio.run(_produce_chunks()), daemon=True)
producer_thread.start()
while True:
item = chunk_queue.get()
if item is done_marker:
break
if isinstance(item, Exception):
raise item
yield item
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
request_start = time.perf_counter()
request_id = normalized.request_id
media_type = normalized.media_type
segment_texts = self._segment_direct_text(normalized)
if not segment_texts:
raise ValueError("text preprocessing returned no valid segments")
if normalized.response_streaming:
return DirectTTSExecution(
media_type=media_type,
streaming=True,
audio_generator=self._iter_scheduler_direct_tts_bytes(normalized),
request_id=request_id,
)
self.api._update_request_state(
request_id,
EngineStatus.CPU_PREPARING,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
)
segment_requests = [
self._build_segment_request(
normalized,
request_id=f"{request_id}_seg_{segment_index:03d}",
text=segment_text,
)
for segment_index, segment_text in enumerate(segment_texts)
]
prepare_profiles: List[Dict[str, Any]] = []
loop = asyncio.get_running_loop()
done_futures: List[asyncio.Future] = []
self.api._update_request_state(
request_id,
EngineStatus.READY_FOR_PREFILL,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_requests)},
)
prepared_items = await asyncio.gather(
*[
self._execute_single_segment_scheduler_job(
normalized,
segment_request=segment_request,
)
for segment_request in segment_requests
]
)
for job, prepare_profile in prepared_items:
prepare_profiles.append(prepare_profile)
done_future = loop.create_future()
done_future.set_result(job)
done_futures.append(done_future)
self.api._update_request_state( self.api._update_request_state(
request_id, request_id,
EngineStatus.ACTIVE_DECODE, EngineStatus.ACTIVE_DECODE,

View File

@ -122,16 +122,6 @@ def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
if normalized.response_streaming:
if normalized.return_fragment or normalized.fixed_length_chunk:
return "legacy_direct_fragment", "fragment_streaming_mode"
return "legacy_direct_streaming", "streaming_mode"
if is_aux_ref_enabled(normalized.aux_ref_audio_paths):
return "legacy_direct_aux_ref", "aux_ref_audio_paths"
if normalized.super_sampling:
return "legacy_direct_super_sampling", "super_sampling"
if normalized.prompt_text in [None, ""]:
return "legacy_direct_missing_prompt", "missing_prompt_text"
return "scheduler_v1_direct", None return "scheduler_v1_direct", None

View File

@ -6,7 +6,7 @@ import uuid
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution
@ -67,39 +67,58 @@ class EngineApiSchedulerFlow:
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
request_start = time.perf_counter() request_start = time.perf_counter()
set_scheduler_seed(seed) set_scheduler_seed(seed)
specs = self._build_scheduler_request_specs(request_items) normalized_requests: List[NormalizedEngineRequest] = []
request_ids = [spec.request_id for spec in specs] for index, payload in enumerate(request_items):
for spec in specs: normalized_requests.append(
self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
error_prefix=f"request[{index}] 参数非法: ",
)
)
specs = [normalized.to_scheduler_spec() for normalized in normalized_requests]
request_ids = [normalized.request_id for normalized in normalized_requests]
for normalized, spec in zip(normalized_requests, specs):
self.api._register_request_state( self.api._register_request_state(
request_id=spec.request_id, request_id=normalized.request_id,
api_mode="scheduler_debug", api_mode="scheduler_debug",
backend="scheduler_debug", backend="scheduler_debug",
media_type="wav", media_type=normalized.media_type,
response_streaming=False, response_streaming=False,
meta={ meta=self.api._build_request_meta(normalized.to_payload()),
"text_len": len(spec.text),
"prompt_text_len": len(spec.prompt_text),
"text_lang": spec.text_lang,
"prompt_lang": spec.prompt_lang,
"ref_audio_path": str(spec.ref_audio_path),
"ready_step": int(spec.ready_step),
},
) )
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None)
prepare_started_at = time.perf_counter() prepare_started_at = time.perf_counter()
original_worker_max_steps = int(self.api.scheduler_worker.max_steps)
original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps)
try: try:
states = await self.api.scheduler_worker.prepare_states_batch_async(specs) self.api.scheduler_worker.max_steps = int(max_steps)
self.api.scheduler_worker.decode_executor.max_steps = int(max_steps)
prepared_payloads = await asyncio.gather(
*[
self.api._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=time.perf_counter(),
engine_request_id=normalized.request_id,
)
for normalized, spec in zip(normalized_requests, specs)
]
)
except Exception as exc: except Exception as exc:
for request_id in request_ids: for request_id in request_ids:
self.api._fail_request_state(request_id, str(exc)) self.api._fail_request_state(request_id, str(exc))
raise raise
finally:
self.api.scheduler_worker.max_steps = int(original_worker_max_steps)
self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps)
prepare_finished_at = time.perf_counter() prepare_finished_at = time.perf_counter()
prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
states = [payload[0] for payload in prepared_payloads]
for state in states: for state in states:
self.api._update_request_state( self.api._update_request_state(
state.request_id, state.request_id,
EngineStatus.ACTIVE_DECODE, EngineStatus.READY_FOR_PREFILL,
{ {
"prepare_profile": dict(state.prepare_profile), "prepare_profile": dict(state.prepare_profile),
"norm_text": state.norm_text, "norm_text": state.norm_text,
@ -108,7 +127,27 @@ class EngineApiSchedulerFlow:
) )
decode_started_at = time.perf_counter() decode_started_at = time.perf_counter()
try: try:
finished = run_scheduler_continuous(self.api.tts.t2s_model.model, states, max_steps=int(max_steps)) loop = asyncio.get_running_loop()
done_futures: List[asyncio.Future] = []
for normalized, state in zip(normalized_requests, states):
done_future = loop.create_future()
done_futures.append(done_future)
await self.api._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
done_loop=loop,
done_future=done_future,
engine_request_id=normalized.request_id,
timeout_sec=normalized.timeout_sec,
)
timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]]
timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0
jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec)))
except Exception as exc: except Exception as exc:
for request_id in request_ids: for request_id in request_ids:
self.api._fail_request_state(request_id, str(exc)) self.api._fail_request_state(request_id, str(exc))
@ -116,46 +155,63 @@ class EngineApiSchedulerFlow:
decode_finished_at = time.perf_counter() decode_finished_at = time.perf_counter()
decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
finished_map = {item.request_id: item for item in finished}
request_profiles: List[Dict[str, Any]] = [] request_profiles: List[Dict[str, Any]] = []
for state in states: finished: List[Dict[str, Any]] = []
item = finished_map.get(state.request_id) finish_reason_counts: Dict[str, int] = {}
if item is None: total_semantic_len = 0
for state, job in zip(states, jobs):
if job.error is not None:
self.api._fail_request_state(state.request_id, str(job.error))
raise RuntimeError(str(job.error))
if job.result is None:
self.api._fail_request_state(state.request_id, "scheduler_debug finished without result") self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
continue raise RuntimeError(f"{state.request_id} finished without result")
request_profile = self.api._build_scheduler_debug_request_profile( job_result = dict(job.result)
state=state, request_profile = {
item=item, **job_result,
batch_request_count=len(states), "backend": "scheduler_debug",
prepare_batch_wall_ms=prepare_batch_wall_ms, "backend_mode": "scheduler_debug",
decode_batch_wall_ms=decode_batch_wall_ms, "batch_request_count": int(len(states)),
batch_request_total_ms=request_total_ms, "batch_prepare_wall_ms": float(prepare_batch_wall_ms),
) "batch_decode_wall_ms": float(decode_batch_wall_ms),
request_profiles.append( "batch_request_total_ms": float(request_total_ms),
"prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_profile": dict(state.prepare_profile),
"norm_text": state.norm_text,
"norm_prompt_text": state.norm_prompt_text,
}
request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)})
self.api._merge_request_state_profile(state.request_id, request_profile)
semantic_len = int(job_result.get("semantic_len", 0))
finish_reason = str(job_result.get("finish_reason", "unknown"))
finished.append(
{ {
"request_id": state.request_id, "request_id": state.request_id,
"profile": dict(request_profile), "semantic_len": semantic_len,
"finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))),
"finish_reason": finish_reason,
} }
) )
self.api._complete_request_state( finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1
state.request_id, total_semantic_len += semantic_len
dict(request_profile),
)
return SchedulerDebugExecution( return SchedulerDebugExecution(
payload={ payload={
"message": "success", "message": "success",
"request_count": len(states), "request_count": len(states),
"max_steps": int(max_steps), "max_steps": int(max_steps),
"batch_profile": self.api._build_scheduler_debug_batch_profile( "batch_profile": {
request_count=len(states), "request_count": int(len(states)),
max_steps=int(max_steps), "max_steps": int(max_steps),
prepare_batch_wall_ms=prepare_batch_wall_ms, "prepare_batch_wall_ms": float(prepare_batch_wall_ms),
decode_batch_wall_ms=decode_batch_wall_ms, "decode_batch_wall_ms": float(decode_batch_wall_ms),
request_total_ms=request_total_ms, "request_total_ms": float(request_total_ms),
finished_items=finished, "total_semantic_len": int(total_semantic_len),
), "finish_reason_counts": finish_reason_counts,
},
"requests": self._summarize_scheduler_states(states), "requests": self._summarize_scheduler_states(states),
"finished": self._summarize_scheduler_finished(finished), "finished": finished,
"request_profiles": request_profiles, "request_profiles": request_profiles,
"request_traces": self.api._collect_request_summaries(request_ids), "request_traces": self.api._collect_request_summaries(request_ids),
} }
@ -222,6 +278,7 @@ class EngineApiSchedulerFlow:
speed_factor=float(normalized.speed_factor), speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps), sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type, media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=prepare_wall_ms, prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=loop, done_loop=loop,

View File

@ -149,6 +149,7 @@ class EngineBridgeDelegates:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None, done_loop: asyncio.AbstractEventLoop | None,
@ -161,6 +162,7 @@ class EngineBridgeDelegates:
speed_factor=speed_factor, speed_factor=speed_factor,
sample_steps=sample_steps, sample_steps=sample_steps,
media_type=media_type, media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms, prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop, done_loop=done_loop,

View File

@ -78,6 +78,7 @@ class EngineStageBridgeFacade:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None, done_loop: asyncio.AbstractEventLoop | None,
@ -90,6 +91,7 @@ class EngineStageBridgeFacade:
speed_factor=speed_factor, speed_factor=speed_factor,
sample_steps=sample_steps, sample_steps=sample_steps,
media_type=media_type, media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms, prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop, done_loop=done_loop,

View File

@ -104,6 +104,7 @@ class NormalizedEngineRequest:
temperature=self.temperature, temperature=self.temperature,
repetition_penalty=self.repetition_penalty, repetition_penalty=self.repetition_penalty,
early_stop_num=self.early_stop_num, early_stop_num=self.early_stop_num,
aux_ref_audio_paths=list(self.aux_ref_audio_paths or []),
ready_step=self.ready_step, ready_step=self.ready_step,
) )

View File

@ -303,6 +303,7 @@ class SchedulerPendingJob:
speed_factor: float speed_factor: float
sample_steps: int sample_steps: int
media_type: str media_type: str
super_sampling: bool = False
admission_wait_ms: float = 0.0 admission_wait_ms: float = 0.0
engine_policy_wait_ms: float = 0.0 engine_policy_wait_ms: float = 0.0
engine_dispatch_wait_ms: float = 0.0 engine_dispatch_wait_ms: float = 0.0

View File

@ -291,6 +291,7 @@ class EngineDispatchTask:
speed_factor: float speed_factor: float
sample_steps: int sample_steps: int
media_type: str media_type: str
super_sampling: bool
prepare_wall_ms: float prepare_wall_ms: float
prepare_profile_total_ms: float prepare_profile_total_ms: float
done_loop: asyncio.AbstractEventLoop | None done_loop: asyncio.AbstractEventLoop | None

View File

@ -113,6 +113,7 @@ class EngineStageCoordinator:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None, done_loop: asyncio.AbstractEventLoop | None,
@ -125,6 +126,7 @@ class EngineStageCoordinator:
speed_factor=speed_factor, speed_factor=speed_factor,
sample_steps=sample_steps, sample_steps=sample_steps,
media_type=media_type, media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms, prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop, done_loop=done_loop,

View File

@ -16,6 +16,7 @@ class EngineDispatchStageMixin:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None, done_loop: asyncio.AbstractEventLoop | None,
@ -29,6 +30,7 @@ class EngineDispatchStageMixin:
speed_factor=float(speed_factor), speed_factor=float(speed_factor),
sample_steps=int(sample_steps), sample_steps=int(sample_steps),
media_type=media_type, media_type=media_type,
super_sampling=bool(super_sampling),
prepare_wall_ms=float(prepare_wall_ms), prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms), prepare_profile_total_ms=float(prepare_profile_total_ms),
done_loop=done_loop, done_loop=done_loop,
@ -66,6 +68,7 @@ class EngineDispatchStageMixin:
speed_factor=dispatch_task.speed_factor, speed_factor=dispatch_task.speed_factor,
sample_steps=dispatch_task.sample_steps, sample_steps=dispatch_task.sample_steps,
media_type=dispatch_task.media_type, media_type=dispatch_task.media_type,
super_sampling=dispatch_task.super_sampling,
prepare_wall_ms=dispatch_task.prepare_wall_ms, prepare_wall_ms=dispatch_task.prepare_wall_ms,
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
done_loop=dispatch_task.done_loop, done_loop=dispatch_task.done_loop,

View File

@ -46,7 +46,7 @@ class UnifiedSchedulerWorker(
self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0")))
self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0")))
self.engine_decode_control_enabled = ( self.engine_decode_control_enabled = (
str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "1")).strip().lower() in {"1", "true", "yes", "on"}
) )
self.job_registry = SchedulerJobRegistry(self.condition) self.job_registry = SchedulerJobRegistry(self.condition)
self.worker_thread: threading.Thread | None = None self.worker_thread: threading.Thread | None = None

View File

@ -149,16 +149,25 @@ class WorkerFinalizeExecutor:
except Exception: except Exception:
pass pass
@staticmethod
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
refer_specs = [job.state.refer_spec]
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
return refer_specs
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
audio_fragment = self.tts.synthesize_audio_request_local( audio_fragment = self.tts.synthesize_audio_request_local(
semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0),
phones=job.state.phones.detach().clone().unsqueeze(0), phones=job.state.phones.detach().clone().unsqueeze(0),
prompt_semantic=job.state.prompt_semantic.detach().clone(), prompt_semantic=job.state.prompt_semantic.detach().clone(),
prompt_phones=job.state.prompt_phones.detach().clone(), prompt_phones=job.state.prompt_phones.detach().clone(),
refer_spec=( refer_spec=[
job.state.refer_spec[0].detach().clone(), (
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), refer_spec_item[0].detach().clone(),
), None if refer_spec_item[1] is None else refer_spec_item[1].detach().clone(),
)
for refer_spec_item in self._collect_job_refer_specs(job)
],
raw_audio=job.state.raw_audio.detach().clone(), raw_audio=job.state.raw_audio.detach().clone(),
raw_sr=int(job.state.raw_sr), raw_sr=int(job.state.raw_sr),
speed=float(job.speed_factor), speed=float(job.speed_factor),
@ -172,7 +181,7 @@ class WorkerFinalizeExecutor:
speed_factor=float(job.speed_factor), speed_factor=float(job.speed_factor),
split_bucket=False, split_bucket=False,
fragment_interval=0.0, fragment_interval=0.0,
super_sampling=False, super_sampling=bool(job.super_sampling),
) )
def _synthesize_finished_audio_batch( def _synthesize_finished_audio_batch(
@ -185,11 +194,14 @@ class WorkerFinalizeExecutor:
speeds = [] speeds = []
sample_steps_list = [] sample_steps_list = []
for job, _ in jobs_and_items: for job, _ in jobs_and_items:
refer_spec_group = self._collect_job_refer_specs(job)
if len(refer_spec_group) != 1:
raise ValueError("batched finalize 暂不支持单请求多参考音频")
refer_specs.append( refer_specs.append(
( [(
job.state.refer_spec[0].detach().clone(), refer_spec_group[0][0].detach().clone(),
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), None if refer_spec_group[0][1] is None else refer_spec_group[0][1].detach().clone(),
) )]
) )
speeds.append(float(job.speed_factor)) speeds.append(float(job.speed_factor))
sample_steps_list.append(int(job.sample_steps)) sample_steps_list.append(int(job.sample_steps))
@ -211,7 +223,7 @@ class WorkerFinalizeExecutor:
speed_factor=float(job.speed_factor), speed_factor=float(job.speed_factor),
split_bucket=False, split_bucket=False,
fragment_interval=0.0, fragment_interval=0.0,
super_sampling=False, super_sampling=bool(job.super_sampling),
) )
) )
return results return results
@ -224,9 +236,12 @@ class WorkerFinalizeExecutor:
return 0.0, [] return 0.0, []
self._sync_device() self._sync_device()
synth_start = time.perf_counter() synth_start = time.perf_counter()
if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: if (
job, item = jobs_and_items[0] len(jobs_and_items) == 1
batch_results = [self._synthesize_finished_audio(job, item)] or self.tts.configs.use_vocoder
or any(len(self._collect_job_refer_specs(job)) != 1 for job, _ in jobs_and_items)
):
batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items]
else: else:
batch_results = self._synthesize_finished_audio_batch(jobs_and_items) batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
self._sync_device() self._sync_device()

View File

@ -78,6 +78,7 @@ class WorkerSubmitLifecycleMixin:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None, done_loop: asyncio.AbstractEventLoop | None = None,
@ -97,6 +98,7 @@ class WorkerSubmitLifecycleMixin:
speed_factor, speed_factor,
sample_steps, sample_steps,
media_type, media_type,
super_sampling,
prepare_wall_ms, prepare_wall_ms,
prepare_profile_total_ms, prepare_profile_total_ms,
done_loop, done_loop,
@ -172,6 +174,7 @@ class WorkerSubmitLifecycleMixin:
speed_factor: float, speed_factor: float,
sample_steps: int, sample_steps: int,
media_type: str, media_type: str,
super_sampling: bool,
prepare_wall_ms: float, prepare_wall_ms: float,
prepare_profile_total_ms: float, prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None, done_loop: asyncio.AbstractEventLoop | None = None,
@ -205,6 +208,7 @@ class WorkerSubmitLifecycleMixin:
speed_factor=float(speed_factor), speed_factor=float(speed_factor),
sample_steps=int(sample_steps), sample_steps=int(sample_steps),
media_type=media_type, media_type=media_type,
super_sampling=bool(super_sampling),
admission_wait_ms=float(admission_wait_ms), admission_wait_ms=float(admission_wait_ms),
engine_policy_wait_ms=float(engine_policy_wait_ms), engine_policy_wait_ms=float(engine_policy_wait_ms),
engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), engine_dispatch_wait_ms=float(engine_dispatch_wait_ms),

View File

@ -39,8 +39,8 @@ POST:
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model. "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3. "sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights`
GET: GET:
``` ```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
``` ```
RESP: RESP:
成功: 返回"success", http code 200 成功: 返回"success", http code 200
@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights`
GET: GET:
``` ```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
``` ```
RESP: RESP:
@ -211,7 +211,7 @@ async def tts_handle(req: dict):
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model. "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3. "sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)

View File

@ -39,8 +39,8 @@ POST:
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model. "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3. "sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights`
GET: GET:
``` ```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
``` ```
RESP: RESP:
成功: 返回"success", http code 200 成功: 返回"success", http code 200
@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights`
GET: GET:
``` ```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
``` ```
RESP: RESP:
@ -280,7 +280,7 @@ async def tts_handle(req: dict):
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model. "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3. "sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)