mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-24 12:58:11 +08:00
Enhance TTS API with improved request handling and asynchronous processing
Refactor api_v2.py and api_v3.py to update sampling parameters and weight paths for better clarity and support for v3/v4 vocoders. Introduce new methods in PrepareCoordinator for handling empty text features and improve profiling capabilities. Additionally, update unified engine components to streamline audio processing and state management, enhancing overall performance and maintainability of the TTS system.
This commit is contained in:
parent
d453a8e47c
commit
6a822b28c3
@ -323,7 +323,7 @@ class TTS_Config:
|
||||
assert isinstance(configs, dict)
|
||||
configs_ = deepcopy(self.default_configs)
|
||||
configs_.update(configs)
|
||||
self.configs: dict = configs_.get("custom", configs_["v2"])
|
||||
self.configs: dict = configs_.get("custom", configs_["v2ProPlus"])
|
||||
self.default_configs = deepcopy(configs_)
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
@ -1872,19 +1872,19 @@ class TTS:
|
||||
self.init_sr_model()
|
||||
if not self.sr_model_not_exist:
|
||||
audio, sr = self.sr_model(audio.unsqueeze(0), sr)
|
||||
max_audio = np.abs(audio).max()
|
||||
if isinstance(audio, torch.Tensor):
|
||||
max_audio = float(torch.abs(audio).max().item())
|
||||
else:
|
||||
max_audio = float(np.abs(audio).max())
|
||||
if max_audio > 1:
|
||||
audio /= max_audio
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
t2 = time.perf_counter()
|
||||
print(f"超采样用时:{t2 - t1:.3f}s")
|
||||
if isinstance(audio, torch.Tensor):
|
||||
audio = audio.detach().float().cpu().numpy()
|
||||
else:
|
||||
# audio = audio.float() * 32768
|
||||
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()
|
||||
|
||||
audio = audio.cpu().numpy()
|
||||
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
audio = np.asarray(audio)
|
||||
audio = (audio.reshape(-1) * 32768).astype(np.int16)
|
||||
|
||||
|
||||
# try:
|
||||
@ -2036,20 +2036,23 @@ class TTS:
|
||||
phones: torch.Tensor,
|
||||
prompt_semantic: torch.Tensor,
|
||||
prompt_phones: torch.Tensor,
|
||||
refer_spec: tuple,
|
||||
refer_spec: tuple | List[tuple],
|
||||
raw_audio: torch.Tensor,
|
||||
raw_sr: int,
|
||||
speed: float = 1.0,
|
||||
sample_steps: int = 32,
|
||||
):
|
||||
refer_audio_spec, audio_tensor = refer_spec
|
||||
refer_specs = list(refer_spec) if isinstance(refer_spec, list) else [refer_spec]
|
||||
refer_audio_spec, audio_tensor = refer_specs[0]
|
||||
if not self.configs.use_vocoder:
|
||||
refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)]
|
||||
refer_audio_spec_list = [item[0].to(dtype=self.precision, device=self.configs.device) for item in refer_specs]
|
||||
sv_emb = None
|
||||
if self.is_v2pro:
|
||||
if audio_tensor is None:
|
||||
raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频"))
|
||||
sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device)
|
||||
sv_emb = []
|
||||
for _, audio_tensor_item in refer_specs:
|
||||
if audio_tensor_item is None:
|
||||
raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频"))
|
||||
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor_item).to(self.configs.device))
|
||||
return self.vits_model.decode(
|
||||
semantic_tokens,
|
||||
phones,
|
||||
@ -2075,7 +2078,7 @@ class TTS:
|
||||
self,
|
||||
semantic_tokens_list: List[torch.Tensor],
|
||||
phones_list: List[torch.Tensor],
|
||||
refer_specs: List[tuple],
|
||||
refer_specs: List[tuple | List[tuple]],
|
||||
speeds: List[float] | None = None,
|
||||
sample_steps_list: List[int] | None = None,
|
||||
) -> List[torch.Tensor]:
|
||||
@ -2118,7 +2121,11 @@ class TTS:
|
||||
semantic_lengths.append(semantic_len)
|
||||
phone_lengths.append(phone_len)
|
||||
|
||||
refer_audio_spec, audio_tensor = refer_specs[batch_index]
|
||||
refer_spec_item = refer_specs[batch_index]
|
||||
refer_spec_group = list(refer_spec_item) if isinstance(refer_spec_item, list) else [refer_spec_item]
|
||||
if len(refer_spec_group) != 1:
|
||||
raise ValueError("batched request-local synthesis 暂不支持单请求多参考音频")
|
||||
refer_audio_spec, audio_tensor = refer_spec_group[0]
|
||||
refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device))
|
||||
if self.is_v2pro:
|
||||
if audio_tensor is None:
|
||||
|
||||
@ -12,6 +12,7 @@ from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
|
||||
PreparedTextFeatures,
|
||||
SchedulerRequestSpec,
|
||||
T2SRequestState,
|
||||
build_empty_text_features,
|
||||
build_request_state_from_parts,
|
||||
normalize_sentence,
|
||||
)
|
||||
@ -118,6 +119,21 @@ class PrepareCoordinator:
|
||||
def _prepare_text_cpu(self, text: str, language: str):
|
||||
return self.tts.prepare_text_segments(text, language)
|
||||
|
||||
@staticmethod
|
||||
def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures:
|
||||
feature_dim = 1024
|
||||
dtype = None
|
||||
if reference is not None:
|
||||
try:
|
||||
feature_dim = int(reference.bert_features.shape[0])
|
||||
dtype = reference.bert_features.dtype
|
||||
except Exception:
|
||||
pass
|
||||
return build_empty_text_features(
|
||||
feature_dim=int(feature_dim),
|
||||
dtype=(dtype if dtype is not None else None) or __import__("torch").float32,
|
||||
)
|
||||
|
||||
def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
|
||||
profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
|
||||
branch_start = time.perf_counter()
|
||||
@ -139,6 +155,9 @@ class PrepareCoordinator:
|
||||
return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args)
|
||||
|
||||
async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult:
|
||||
if text in [None, ""]:
|
||||
submit_at = time.perf_counter()
|
||||
return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at)
|
||||
executor = getattr(self.tts, "prepare_text_cpu_executor", None)
|
||||
if executor is None:
|
||||
submit_at = time.perf_counter()
|
||||
@ -164,19 +183,71 @@ class PrepareCoordinator:
|
||||
prompt_cpu_run_ms: float,
|
||||
target_cpu_run_ms: float,
|
||||
) -> tuple[ProfiledResult, ProfiledResult]:
|
||||
prompt_is_empty = len(prompt_segments or []) == 0
|
||||
if self.text_feature_executor is not None:
|
||||
prompt_feature_task = asyncio.create_task(
|
||||
self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms)
|
||||
target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms))
|
||||
if not prompt_is_empty:
|
||||
prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms))
|
||||
return await asyncio.gather(prompt_feature_task, target_feature_task)
|
||||
target_profiled = await target_feature_task
|
||||
submit_at = time.perf_counter()
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=self._build_empty_text_features_like(target_profiled.result),
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(submit_at),
|
||||
finished_at=float(submit_at),
|
||||
)
|
||||
target_feature_task = asyncio.create_task(
|
||||
self._run_text_feature_stage(target_segments, None, target_cpu_run_ms)
|
||||
)
|
||||
return await asyncio.gather(prompt_feature_task, target_feature_task)
|
||||
return prompt_profiled, target_profiled
|
||||
|
||||
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
|
||||
target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
|
||||
submit_at = time.perf_counter()
|
||||
started_at = float(submit_at)
|
||||
if prompt_is_empty:
|
||||
target_result_raw = await self.tts.build_text_features_from_segments_async(
|
||||
target_segments,
|
||||
profile=target_profile,
|
||||
)
|
||||
prompt_result = self._build_empty_text_features_like(
|
||||
PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
bert_features=target_result_raw[1],
|
||||
norm_text=target_result_raw[2],
|
||||
profile=target_profile,
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=prompt_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(submit_at),
|
||||
finished_at=float(submit_at),
|
||||
)
|
||||
target_result = PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
bert_features=target_result_raw[1],
|
||||
norm_text=target_result_raw[2],
|
||||
profile=target_profile,
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
target_profiled = ProfiledResult(
|
||||
result=target_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
|
||||
)
|
||||
if finished_at > target_profiled.finished_at:
|
||||
target_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(target_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
else:
|
||||
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
|
||||
return prompt_profiled, target_profiled
|
||||
|
||||
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
|
||||
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
@ -35,6 +36,7 @@ class SchedulerRequestSpec:
|
||||
temperature: float
|
||||
repetition_penalty: float
|
||||
early_stop_num: int
|
||||
aux_ref_audio_paths: List[str] = field(default_factory=list)
|
||||
ready_step: int = 0
|
||||
|
||||
|
||||
@ -54,6 +56,7 @@ class T2SRequestState:
|
||||
all_bert_features: torch.Tensor
|
||||
prompt_semantic: torch.LongTensor
|
||||
refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]]
|
||||
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
|
||||
raw_audio: torch.Tensor
|
||||
raw_sr: int
|
||||
top_k: int
|
||||
@ -113,6 +116,21 @@ class PreparedTextFeatures:
|
||||
cpu_preprocess_ms: float
|
||||
|
||||
|
||||
def build_empty_text_features(
|
||||
*,
|
||||
feature_dim: int = 1024,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> PreparedTextFeatures:
|
||||
return PreparedTextFeatures(
|
||||
phones=[],
|
||||
bert_features=torch.empty((int(feature_dim), 0), dtype=dtype),
|
||||
norm_text="",
|
||||
profile={"cpu_preprocess_ms": 0.0, "bert_total_ms": 0.0},
|
||||
total_ms=0.0,
|
||||
cpu_preprocess_ms=0.0,
|
||||
)
|
||||
|
||||
|
||||
def normalize_sentence(text: str, language: str) -> str:
|
||||
text = text.strip("\n").strip()
|
||||
if not text:
|
||||
@ -171,6 +189,14 @@ def build_request_state_from_parts(
|
||||
bundle_profile = ref_audio_bundle.get("profile", {})
|
||||
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
|
||||
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
|
||||
aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
|
||||
for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []):
|
||||
if aux_ref_audio_path in [None, ""]:
|
||||
continue
|
||||
if not os.path.exists(str(aux_ref_audio_path)):
|
||||
continue
|
||||
aux_spec_audio, aux_audio_16k, _, _ = tts.extract_ref_spec(str(aux_ref_audio_path))
|
||||
aux_refer_specs.append((aux_spec_audio, aux_audio_16k))
|
||||
raw_audio = ref_audio_bundle["raw_audio"]
|
||||
raw_sr = int(ref_audio_bundle["raw_sr"])
|
||||
prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
|
||||
@ -182,9 +208,9 @@ def build_request_state_from_parts(
|
||||
phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device)
|
||||
prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device)
|
||||
all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device)
|
||||
all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to(
|
||||
dtype=tts.precision, device=tts.configs.device
|
||||
)
|
||||
prompt_bert_features = prompt_result.bert_features.to(dtype=tts.precision, device=tts.configs.device)
|
||||
target_bert_features = target_result.bert_features.to(dtype=tts.precision, device=tts.configs.device)
|
||||
all_bert_features = torch.cat([prompt_bert_features, target_bert_features], dim=1)
|
||||
_sync_device(device)
|
||||
tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0
|
||||
|
||||
@ -280,6 +306,7 @@ def build_request_state_from_parts(
|
||||
all_bert_features=all_bert_features,
|
||||
prompt_semantic=prompt_semantic,
|
||||
refer_spec=(spec_audio, audio_16k),
|
||||
aux_refer_specs=aux_refer_specs,
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
top_k=spec.top_k,
|
||||
@ -301,10 +328,16 @@ def prepare_request_state(
|
||||
prepare_sync_start = time.perf_counter()
|
||||
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
|
||||
text = spec.text.strip("\n")
|
||||
prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang)
|
||||
target_result = prepare_text_features(tts, text, spec.text_lang)
|
||||
if target_result.phones is None:
|
||||
raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
|
||||
if prompt_text in [None, ""]:
|
||||
prompt_result = build_empty_text_features(
|
||||
feature_dim=int(target_result.bert_features.shape[0]),
|
||||
dtype=target_result.bert_features.dtype,
|
||||
)
|
||||
else:
|
||||
prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang)
|
||||
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
|
||||
return build_request_state_from_parts(
|
||||
tts=tts,
|
||||
|
||||
@ -119,6 +119,7 @@ class EngineApiFacade:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
@ -131,6 +132,7 @@ class EngineApiFacade:
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
@ -122,6 +124,173 @@ class EngineApiDirectFlow:
|
||||
payload["response_streaming"] = False
|
||||
return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ")
|
||||
|
||||
async def _execute_single_segment_scheduler_job(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
segment_request: NormalizedEngineRequest,
|
||||
) -> tuple[SchedulerPendingJob, Dict[str, Any]]:
|
||||
spec = self.api._build_scheduler_submit_spec(segment_request)
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=time.perf_counter(),
|
||||
engine_request_id=None,
|
||||
)
|
||||
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
|
||||
prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
super_sampling=bool(normalized.super_sampling),
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=None,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
|
||||
job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec)
|
||||
return job, {
|
||||
"request_id": spec.request_id,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
}
|
||||
|
||||
def _iter_scheduler_direct_tts_bytes(self, normalized: NormalizedEngineRequest) -> Generator[bytes, None, None]:
|
||||
request_start = time.perf_counter()
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
segment_texts = self._segment_direct_text(normalized)
|
||||
if not segment_texts:
|
||||
raise ValueError("text preprocessing returned no valid segments")
|
||||
chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8)
|
||||
done_marker = object()
|
||||
|
||||
async def _produce_chunks() -> None:
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.CPU_PREPARING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
|
||||
)
|
||||
sample_rate: int | None = None
|
||||
current_media_type = media_type
|
||||
chunk_count = 0
|
||||
stream_total_bytes = 0
|
||||
first_chunk_ms: float | None = None
|
||||
prepare_profiles: List[Dict[str, Any]] = []
|
||||
worker_profiles: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for segment_index, segment_text in enumerate(segment_texts):
|
||||
segment_request = self._build_segment_request(
|
||||
normalized,
|
||||
request_id=f"{request_id}_seg_{segment_index:03d}",
|
||||
text=segment_text,
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{
|
||||
"backend": "scheduler_v1_direct",
|
||||
"backend_mode": "scheduler_v1_direct",
|
||||
"segment_index": segment_index,
|
||||
"segment_count": len(segment_texts),
|
||||
},
|
||||
)
|
||||
job, prepare_profile = await self._execute_single_segment_scheduler_job(
|
||||
normalized,
|
||||
segment_request=segment_request,
|
||||
)
|
||||
prepare_profiles.append(prepare_profile)
|
||||
if job.error is not None:
|
||||
raise RuntimeError(job.error)
|
||||
if job.audio_data is None or job.sample_rate is None or job.result is None:
|
||||
raise RuntimeError(f"{job.request_id} finished without audio result")
|
||||
worker_profiles.append(dict(job.result))
|
||||
if sample_rate is None:
|
||||
sample_rate = int(job.sample_rate)
|
||||
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.STREAMING,
|
||||
{
|
||||
"backend": "scheduler_v1_direct",
|
||||
"backend_mode": "scheduler_v1_direct",
|
||||
"sample_rate": int(sample_rate),
|
||||
},
|
||||
)
|
||||
if media_type == "wav":
|
||||
header = wave_header_chunk(sample_rate=int(sample_rate))
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(header)
|
||||
chunk_queue.put(header)
|
||||
current_media_type = "raw"
|
||||
packed_chunk = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), current_media_type).getvalue()
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(packed_chunk)
|
||||
chunk_queue.put(packed_chunk)
|
||||
if segment_index + 1 < len(segment_texts):
|
||||
silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate))
|
||||
if silence_samples > 0:
|
||||
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
|
||||
packed_silence = pack_audio(
|
||||
BytesIO(), silence_chunk, int(job.sample_rate), current_media_type
|
||||
).getvalue()
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(packed_silence)
|
||||
chunk_queue.put(packed_silence)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
chunk_queue.put(exc)
|
||||
else:
|
||||
self.api._merge_request_state_profile(
|
||||
request_id,
|
||||
{
|
||||
"prepare_aggregate": self.api._aggregate_numeric_dicts(
|
||||
[item["prepare_profile"] for item in prepare_profiles]
|
||||
),
|
||||
"engine_policy_wait_ms": sum(
|
||||
float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles
|
||||
),
|
||||
"engine_dispatch_wait_ms": sum(
|
||||
float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles
|
||||
),
|
||||
},
|
||||
)
|
||||
direct_profile = self.api._build_direct_scheduler_profile(
|
||||
backend="scheduler_v1_direct",
|
||||
request_start=request_start,
|
||||
response_ready_at=time.perf_counter(),
|
||||
audio_bytes=stream_total_bytes,
|
||||
sample_rate=int(sample_rate or 0),
|
||||
segment_texts=segment_texts,
|
||||
prepare_profiles=prepare_profiles,
|
||||
worker_profiles=worker_profiles,
|
||||
pack_ms=0.0,
|
||||
response_overhead_ms=0.0,
|
||||
)
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms),
|
||||
)
|
||||
finally:
|
||||
chunk_queue.put(done_marker)
|
||||
|
||||
producer_thread = threading.Thread(target=lambda: asyncio.run(_produce_chunks()), daemon=True)
|
||||
producer_thread.start()
|
||||
while True:
|
||||
item = chunk_queue.get()
|
||||
if item is done_marker:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
|
||||
request_start = time.perf_counter()
|
||||
request_id = normalized.request_id
|
||||
@ -129,63 +298,48 @@ class EngineApiDirectFlow:
|
||||
segment_texts = self._segment_direct_text(normalized)
|
||||
if not segment_texts:
|
||||
raise ValueError("text preprocessing returned no valid segments")
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_scheduler_direct_tts_bytes(normalized),
|
||||
request_id=request_id,
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.CPU_PREPARING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
|
||||
)
|
||||
segment_specs = []
|
||||
for segment_index, segment_text in enumerate(segment_texts):
|
||||
segment_request = self._build_segment_request(
|
||||
segment_requests = [
|
||||
self._build_segment_request(
|
||||
normalized,
|
||||
request_id=f"{request_id}_seg_{segment_index:03d}",
|
||||
text=segment_text,
|
||||
)
|
||||
segment_specs.append(self.api._build_scheduler_submit_spec(segment_request))
|
||||
|
||||
prepared_items = await asyncio.gather(
|
||||
*[
|
||||
self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=time.perf_counter(),
|
||||
engine_request_id=None,
|
||||
)
|
||||
for spec in segment_specs
|
||||
]
|
||||
)
|
||||
for segment_index, segment_text in enumerate(segment_texts)
|
||||
]
|
||||
prepare_profiles: List[Dict[str, Any]] = []
|
||||
loop = asyncio.get_running_loop()
|
||||
done_futures: List[asyncio.Future] = []
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)},
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_requests)},
|
||||
)
|
||||
for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items):
|
||||
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
|
||||
prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
prepare_profiles.append(
|
||||
{
|
||||
"request_id": spec.request_id,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
}
|
||||
)
|
||||
prepared_items = await asyncio.gather(
|
||||
*[
|
||||
self._execute_single_segment_scheduler_job(
|
||||
normalized,
|
||||
segment_request=segment_request,
|
||||
)
|
||||
for segment_request in segment_requests
|
||||
]
|
||||
)
|
||||
for job, prepare_profile in prepared_items:
|
||||
prepare_profiles.append(prepare_profile)
|
||||
done_future = loop.create_future()
|
||||
done_future.set_result(job)
|
||||
done_futures.append(done_future)
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=media_type,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=None,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
|
||||
@ -122,16 +122,6 @@ def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
|
||||
|
||||
|
||||
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
|
||||
if normalized.response_streaming:
|
||||
if normalized.return_fragment or normalized.fixed_length_chunk:
|
||||
return "legacy_direct_fragment", "fragment_streaming_mode"
|
||||
return "legacy_direct_streaming", "streaming_mode"
|
||||
if is_aux_ref_enabled(normalized.aux_ref_audio_paths):
|
||||
return "legacy_direct_aux_ref", "aux_ref_audio_paths"
|
||||
if normalized.super_sampling:
|
||||
return "legacy_direct_super_sampling", "super_sampling"
|
||||
if normalized.prompt_text in [None, ""]:
|
||||
return "legacy_direct_missing_prompt", "missing_prompt_text"
|
||||
return "scheduler_v1_direct", None
|
||||
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution
|
||||
|
||||
@ -67,39 +67,58 @@ class EngineApiSchedulerFlow:
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
request_start = time.perf_counter()
|
||||
set_scheduler_seed(seed)
|
||||
specs = self._build_scheduler_request_specs(request_items)
|
||||
request_ids = [spec.request_id for spec in specs]
|
||||
for spec in specs:
|
||||
normalized_requests: List[NormalizedEngineRequest] = []
|
||||
for index, payload in enumerate(request_items):
|
||||
normalized_requests.append(
|
||||
self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
|
||||
error_prefix=f"request[{index}] 参数非法: ",
|
||||
)
|
||||
)
|
||||
specs = [normalized.to_scheduler_spec() for normalized in normalized_requests]
|
||||
request_ids = [normalized.request_id for normalized in normalized_requests]
|
||||
for normalized, spec in zip(normalized_requests, specs):
|
||||
self.api._register_request_state(
|
||||
request_id=spec.request_id,
|
||||
request_id=normalized.request_id,
|
||||
api_mode="scheduler_debug",
|
||||
backend="scheduler_debug",
|
||||
media_type="wav",
|
||||
media_type=normalized.media_type,
|
||||
response_streaming=False,
|
||||
meta={
|
||||
"text_len": len(spec.text),
|
||||
"prompt_text_len": len(spec.prompt_text),
|
||||
"text_lang": spec.text_lang,
|
||||
"prompt_lang": spec.prompt_lang,
|
||||
"ref_audio_path": str(spec.ref_audio_path),
|
||||
"ready_step": int(spec.ready_step),
|
||||
},
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None)
|
||||
self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
|
||||
self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None)
|
||||
prepare_started_at = time.perf_counter()
|
||||
original_worker_max_steps = int(self.api.scheduler_worker.max_steps)
|
||||
original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps)
|
||||
try:
|
||||
states = await self.api.scheduler_worker.prepare_states_batch_async(specs)
|
||||
self.api.scheduler_worker.max_steps = int(max_steps)
|
||||
self.api.scheduler_worker.decode_executor.max_steps = int(max_steps)
|
||||
prepared_payloads = await asyncio.gather(
|
||||
*[
|
||||
self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=time.perf_counter(),
|
||||
engine_request_id=normalized.request_id,
|
||||
)
|
||||
for normalized, spec in zip(normalized_requests, specs)
|
||||
]
|
||||
)
|
||||
except Exception as exc:
|
||||
for request_id in request_ids:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
finally:
|
||||
self.api.scheduler_worker.max_steps = int(original_worker_max_steps)
|
||||
self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps)
|
||||
prepare_finished_at = time.perf_counter()
|
||||
prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
|
||||
states = [payload[0] for payload in prepared_payloads]
|
||||
for state in states:
|
||||
self.api._update_request_state(
|
||||
state.request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
"norm_text": state.norm_text,
|
||||
@ -108,7 +127,27 @@ class EngineApiSchedulerFlow:
|
||||
)
|
||||
decode_started_at = time.perf_counter()
|
||||
try:
|
||||
finished = run_scheduler_continuous(self.api.tts.t2s_model.model, states, max_steps=int(max_steps))
|
||||
loop = asyncio.get_running_loop()
|
||||
done_futures: List[asyncio.Future] = []
|
||||
for normalized, state in zip(normalized_requests, states):
|
||||
done_future = loop.create_future()
|
||||
done_futures.append(done_future)
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
super_sampling=bool(normalized.super_sampling),
|
||||
prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=normalized.request_id,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]]
|
||||
timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0
|
||||
jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec)))
|
||||
except Exception as exc:
|
||||
for request_id in request_ids:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
@ -116,46 +155,63 @@ class EngineApiSchedulerFlow:
|
||||
decode_finished_at = time.perf_counter()
|
||||
decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
|
||||
request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
|
||||
finished_map = {item.request_id: item for item in finished}
|
||||
request_profiles: List[Dict[str, Any]] = []
|
||||
for state in states:
|
||||
item = finished_map.get(state.request_id)
|
||||
if item is None:
|
||||
finished: List[Dict[str, Any]] = []
|
||||
finish_reason_counts: Dict[str, int] = {}
|
||||
total_semantic_len = 0
|
||||
for state, job in zip(states, jobs):
|
||||
if job.error is not None:
|
||||
self.api._fail_request_state(state.request_id, str(job.error))
|
||||
raise RuntimeError(str(job.error))
|
||||
if job.result is None:
|
||||
self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
|
||||
continue
|
||||
request_profile = self.api._build_scheduler_debug_request_profile(
|
||||
state=state,
|
||||
item=item,
|
||||
batch_request_count=len(states),
|
||||
prepare_batch_wall_ms=prepare_batch_wall_ms,
|
||||
decode_batch_wall_ms=decode_batch_wall_ms,
|
||||
batch_request_total_ms=request_total_ms,
|
||||
)
|
||||
request_profiles.append(
|
||||
raise RuntimeError(f"{state.request_id} finished without result")
|
||||
job_result = dict(job.result)
|
||||
request_profile = {
|
||||
**job_result,
|
||||
"backend": "scheduler_debug",
|
||||
"backend_mode": "scheduler_debug",
|
||||
"batch_request_count": int(len(states)),
|
||||
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
|
||||
"batch_decode_wall_ms": float(decode_batch_wall_ms),
|
||||
"batch_request_total_ms": float(request_total_ms),
|
||||
"prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
"prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
"prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
"norm_text": state.norm_text,
|
||||
"norm_prompt_text": state.norm_prompt_text,
|
||||
}
|
||||
request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)})
|
||||
self.api._merge_request_state_profile(state.request_id, request_profile)
|
||||
semantic_len = int(job_result.get("semantic_len", 0))
|
||||
finish_reason = str(job_result.get("finish_reason", "unknown"))
|
||||
finished.append(
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
"profile": dict(request_profile),
|
||||
"semantic_len": semantic_len,
|
||||
"finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
)
|
||||
self.api._complete_request_state(
|
||||
state.request_id,
|
||||
dict(request_profile),
|
||||
)
|
||||
finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1
|
||||
total_semantic_len += semantic_len
|
||||
return SchedulerDebugExecution(
|
||||
payload={
|
||||
"message": "success",
|
||||
"request_count": len(states),
|
||||
"max_steps": int(max_steps),
|
||||
"batch_profile": self.api._build_scheduler_debug_batch_profile(
|
||||
request_count=len(states),
|
||||
max_steps=int(max_steps),
|
||||
prepare_batch_wall_ms=prepare_batch_wall_ms,
|
||||
decode_batch_wall_ms=decode_batch_wall_ms,
|
||||
request_total_ms=request_total_ms,
|
||||
finished_items=finished,
|
||||
),
|
||||
"batch_profile": {
|
||||
"request_count": int(len(states)),
|
||||
"max_steps": int(max_steps),
|
||||
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
|
||||
"decode_batch_wall_ms": float(decode_batch_wall_ms),
|
||||
"request_total_ms": float(request_total_ms),
|
||||
"total_semantic_len": int(total_semantic_len),
|
||||
"finish_reason_counts": finish_reason_counts,
|
||||
},
|
||||
"requests": self._summarize_scheduler_states(states),
|
||||
"finished": self._summarize_scheduler_finished(finished),
|
||||
"finished": finished,
|
||||
"request_profiles": request_profiles,
|
||||
"request_traces": self.api._collect_request_summaries(request_ids),
|
||||
}
|
||||
@ -222,6 +278,7 @@ class EngineApiSchedulerFlow:
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
super_sampling=bool(normalized.super_sampling),
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
|
||||
@ -149,6 +149,7 @@ class EngineBridgeDelegates:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
@ -161,6 +162,7 @@ class EngineBridgeDelegates:
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
|
||||
@ -78,6 +78,7 @@ class EngineStageBridgeFacade:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
@ -90,6 +91,7 @@ class EngineStageBridgeFacade:
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
|
||||
@ -104,6 +104,7 @@ class NormalizedEngineRequest:
|
||||
temperature=self.temperature,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
early_stop_num=self.early_stop_num,
|
||||
aux_ref_audio_paths=list(self.aux_ref_audio_paths or []),
|
||||
ready_step=self.ready_step,
|
||||
)
|
||||
|
||||
|
||||
@ -303,6 +303,7 @@ class SchedulerPendingJob:
|
||||
speed_factor: float
|
||||
sample_steps: int
|
||||
media_type: str
|
||||
super_sampling: bool = False
|
||||
admission_wait_ms: float = 0.0
|
||||
engine_policy_wait_ms: float = 0.0
|
||||
engine_dispatch_wait_ms: float = 0.0
|
||||
|
||||
@ -291,6 +291,7 @@ class EngineDispatchTask:
|
||||
speed_factor: float
|
||||
sample_steps: int
|
||||
media_type: str
|
||||
super_sampling: bool
|
||||
prepare_wall_ms: float
|
||||
prepare_profile_total_ms: float
|
||||
done_loop: asyncio.AbstractEventLoop | None
|
||||
|
||||
@ -113,6 +113,7 @@ class EngineStageCoordinator:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
@ -125,6 +126,7 @@ class EngineStageCoordinator:
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
|
||||
@ -16,6 +16,7 @@ class EngineDispatchStageMixin:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
@ -29,6 +30,7 @@ class EngineDispatchStageMixin:
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
super_sampling=bool(super_sampling),
|
||||
prepare_wall_ms=float(prepare_wall_ms),
|
||||
prepare_profile_total_ms=float(prepare_profile_total_ms),
|
||||
done_loop=done_loop,
|
||||
@ -66,6 +68,7 @@ class EngineDispatchStageMixin:
|
||||
speed_factor=dispatch_task.speed_factor,
|
||||
sample_steps=dispatch_task.sample_steps,
|
||||
media_type=dispatch_task.media_type,
|
||||
super_sampling=dispatch_task.super_sampling,
|
||||
prepare_wall_ms=dispatch_task.prepare_wall_ms,
|
||||
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
|
||||
done_loop=dispatch_task.done_loop,
|
||||
|
||||
@ -46,7 +46,7 @@ class UnifiedSchedulerWorker(
|
||||
self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0")))
|
||||
self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0")))
|
||||
self.engine_decode_control_enabled = (
|
||||
str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"}
|
||||
str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "1")).strip().lower() in {"1", "true", "yes", "on"}
|
||||
)
|
||||
self.job_registry = SchedulerJobRegistry(self.condition)
|
||||
self.worker_thread: threading.Thread | None = None
|
||||
|
||||
@ -149,16 +149,25 @@ class WorkerFinalizeExecutor:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
|
||||
refer_specs = [job.state.refer_spec]
|
||||
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
|
||||
return refer_specs
|
||||
|
||||
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
|
||||
audio_fragment = self.tts.synthesize_audio_request_local(
|
||||
semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0),
|
||||
phones=job.state.phones.detach().clone().unsqueeze(0),
|
||||
prompt_semantic=job.state.prompt_semantic.detach().clone(),
|
||||
prompt_phones=job.state.prompt_phones.detach().clone(),
|
||||
refer_spec=(
|
||||
job.state.refer_spec[0].detach().clone(),
|
||||
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
|
||||
),
|
||||
refer_spec=[
|
||||
(
|
||||
refer_spec_item[0].detach().clone(),
|
||||
None if refer_spec_item[1] is None else refer_spec_item[1].detach().clone(),
|
||||
)
|
||||
for refer_spec_item in self._collect_job_refer_specs(job)
|
||||
],
|
||||
raw_audio=job.state.raw_audio.detach().clone(),
|
||||
raw_sr=int(job.state.raw_sr),
|
||||
speed=float(job.speed_factor),
|
||||
@ -172,7 +181,7 @@ class WorkerFinalizeExecutor:
|
||||
speed_factor=float(job.speed_factor),
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=False,
|
||||
super_sampling=bool(job.super_sampling),
|
||||
)
|
||||
|
||||
def _synthesize_finished_audio_batch(
|
||||
@ -185,11 +194,14 @@ class WorkerFinalizeExecutor:
|
||||
speeds = []
|
||||
sample_steps_list = []
|
||||
for job, _ in jobs_and_items:
|
||||
refer_spec_group = self._collect_job_refer_specs(job)
|
||||
if len(refer_spec_group) != 1:
|
||||
raise ValueError("batched finalize 暂不支持单请求多参考音频")
|
||||
refer_specs.append(
|
||||
(
|
||||
job.state.refer_spec[0].detach().clone(),
|
||||
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
|
||||
)
|
||||
[(
|
||||
refer_spec_group[0][0].detach().clone(),
|
||||
None if refer_spec_group[0][1] is None else refer_spec_group[0][1].detach().clone(),
|
||||
)]
|
||||
)
|
||||
speeds.append(float(job.speed_factor))
|
||||
sample_steps_list.append(int(job.sample_steps))
|
||||
@ -211,7 +223,7 @@ class WorkerFinalizeExecutor:
|
||||
speed_factor=float(job.speed_factor),
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=False,
|
||||
super_sampling=bool(job.super_sampling),
|
||||
)
|
||||
)
|
||||
return results
|
||||
@ -224,9 +236,12 @@ class WorkerFinalizeExecutor:
|
||||
return 0.0, []
|
||||
self._sync_device()
|
||||
synth_start = time.perf_counter()
|
||||
if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder:
|
||||
job, item = jobs_and_items[0]
|
||||
batch_results = [self._synthesize_finished_audio(job, item)]
|
||||
if (
|
||||
len(jobs_and_items) == 1
|
||||
or self.tts.configs.use_vocoder
|
||||
or any(len(self._collect_job_refer_specs(job)) != 1 for job, _ in jobs_and_items)
|
||||
):
|
||||
batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items]
|
||||
else:
|
||||
batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
|
||||
self._sync_device()
|
||||
|
||||
@ -78,6 +78,7 @@ class WorkerSubmitLifecycleMixin:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None = None,
|
||||
@ -97,6 +98,7 @@ class WorkerSubmitLifecycleMixin:
|
||||
speed_factor,
|
||||
sample_steps,
|
||||
media_type,
|
||||
super_sampling,
|
||||
prepare_wall_ms,
|
||||
prepare_profile_total_ms,
|
||||
done_loop,
|
||||
@ -172,6 +174,7 @@ class WorkerSubmitLifecycleMixin:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None = None,
|
||||
@ -205,6 +208,7 @@ class WorkerSubmitLifecycleMixin:
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
super_sampling=bool(super_sampling),
|
||||
admission_wait_ms=float(admission_wait_ms),
|
||||
engine_policy_wait_ms=float(engine_policy_wait_ms),
|
||||
engine_dispatch_wait_ms=float(engine_dispatch_wait_ms),
|
||||
|
||||
10
api_v2.py
10
api_v2.py
@ -39,8 +39,8 @@ POST:
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
|
||||
"super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||
```
|
||||
RESP:
|
||||
成功: 返回"success", http code 200
|
||||
@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
|
||||
```
|
||||
|
||||
RESP:
|
||||
@ -211,7 +211,7 @@ async def tts_handle(req: dict):
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
|
||||
10
api_v3.py
10
api_v3.py
@ -39,8 +39,8 @@ POST:
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
|
||||
"super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||
```
|
||||
RESP:
|
||||
成功: 返回"success", http code 200
|
||||
@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
|
||||
```
|
||||
|
||||
RESP:
|
||||
@ -280,7 +280,7 @@ async def tts_handle(req: dict):
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user