diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c259662..d475b804 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,26 +1,27 @@ import gc +import concurrent.futures import math import os import random import sys import time import traceback -from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -import torchaudio -from tqdm import tqdm - now_dir = os.getcwd() sys.path.append(now_dir) -import os from typing import List, Tuple, Union +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import ffmpeg import librosa import numpy as np import torch import torch.nn.functional as F +import torchaudio import yaml from AR.models.t2s_lightning_module import Text2SemanticLightningModule from BigVGAN.bigvgan import BigVGAN @@ -30,6 +31,7 @@ from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator from peft import LoraConfig, get_peft_model from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new from transformers import AutoModelForMaskedLM, AutoTokenizer +from tqdm import tqdm from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list @@ -449,20 +451,21 @@ class TTS: "overlapped_len": None, } self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) - self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None - default_text_cpu_workers = 16 self.prepare_text_cpu_workers = max( 0, - int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), ) - self.prepare_text_cpu_executor = None - if self.prepare_text_cpu_workers > 0: - self.prepare_text_cpu_executor = ThreadPoolExecutor( + self.prepare_text_cpu_executor = ( + concurrent.futures.ThreadPoolExecutor( max_workers=self.prepare_text_cpu_workers, thread_name_prefix="prepare-text-cpu", ) + if self.prepare_text_cpu_workers > 0 + else None + ) self._init_models() @@ -475,6 +478,20 @@ class TTS: batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")), + high_pressure_pending_threshold=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0") + ), + high_pressure_batch_window_ms=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1") + ), + high_pressure_max_batch_items=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32") + ), + high_pressure_max_batch_tokens=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192") + ), ) if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") @@ -830,13 +847,23 @@ class TTS: return codes[0, 0].to(self.configs.device) @torch.inference_mode() - def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + cpu_prepare_start = time.perf_counter() wav16k = prepare_prompt_semantic_wav16k( raw_audio=raw_audio, raw_sr=raw_sr, zero_wav_samples=int(self.configs.sampling_rate * 0.3), ) - return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, cpu_prepare_ms, forward_ms + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + return prompt_semantic def extract_prompt_semantic(self, ref_wav_path: str): raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) @@ -887,7 +914,9 @@ class TTS: if self.prepare_ref_semantic_batch_worker is None: with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() - prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( + self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + ) prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_spec_start = time.perf_counter() refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] @@ -897,8 +926,8 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), - "prompt_semantic_cpu_prepare_ms": 0.0, - "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), + "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_scatter_ms": 0.0, "prompt_semantic_stage_slots": float(limiter_stats["slots"]), "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), @@ -1012,6 +1041,32 @@ class TTS: text, language, self.configs.version, profile=profile ) + def prepare_text_segments(self, text: str, language: str): + return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version) + + def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None): + return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None): + return await self.text_preprocessor.build_phones_and_bert_from_segments_async( + prepared_segments, + profile=profile, + ) + + async def build_text_feature_pair_from_segments_async( + self, + prompt_segments, + target_segments, + prompt_profile: dict | None = None, + target_profile: dict | None = None, + ): + return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path @@ -2011,6 +2066,79 @@ class TTS: sample_steps=sample_steps, ) + @torch.inference_mode() + def synthesize_audio_requests_local_batched( + self, + semantic_tokens_list: List[torch.Tensor], + phones_list: List[torch.Tensor], + refer_specs: List[tuple], + speeds: List[float] | None = None, + sample_steps_list: List[int] | None = None, + ) -> List[torch.Tensor]: + batch_size = len(semantic_tokens_list) + if batch_size == 0: + return [] + if len(phones_list) != batch_size or len(refer_specs) != batch_size: + raise ValueError("batched request-local synthesis 输入长度不一致") + if speeds is None: + speeds = [1.0] * batch_size + if sample_steps_list is None: + sample_steps_list = [32] * batch_size + if len(speeds) != batch_size or len(sample_steps_list) != batch_size: + raise ValueError("batched request-local synthesis 参数长度不一致") + first_speed = float(speeds[0]) + first_sample_steps = int(sample_steps_list[0]) + if any(abs(float(item) - first_speed) > 1e-6 for item in speeds): + raise ValueError("batched request-local synthesis 目前要求 speed 一致") + if any(int(item) != first_sample_steps for item in sample_steps_list): + raise ValueError("batched request-local synthesis 目前要求 sample_steps 一致") + if self.configs.use_vocoder: + raise NotImplementedError("request-local batched VITS synthesis 暂不支持 vocoder 模型") + + device = self.configs.device + max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list) + max_phone_len = max(int(item.shape[-1]) for item in phones_list) + semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device) + phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device) + semantic_lengths = [] + phone_lengths = [] + refer_audio_specs: List[torch.Tensor] = [] + sv_emb_batch = None + sv_emb_list: List[torch.Tensor] = [] + + for batch_index, semantic_tokens in enumerate(semantic_tokens_list): + semantic_len = int(semantic_tokens.shape[-1]) + phone_len = int(phones_list[batch_index].shape[-1]) + semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long) + phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long) + semantic_lengths.append(semantic_len) + phone_lengths.append(phone_len) + + refer_audio_spec, audio_tensor = refer_specs[batch_index] + refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device)) + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频")) + sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device)) + + if self.is_v2pro: + sv_emb_batch = torch.cat(sv_emb_list, dim=0) + + audio_batch, audio_lengths = self.vits_model.decode_batched_request_local( + codes=semantic_batch, + code_lengths=torch.LongTensor(semantic_lengths).to(device), + text=phone_batch, + text_lengths=torch.LongTensor(phone_lengths).to(device), + refer_list=refer_audio_specs, + speed=first_speed, + sv_emb=sv_emb_batch, + ) + audios: List[torch.Tensor] = [] + for batch_index in range(batch_size): + audio_len = int(audio_lengths[batch_index].item()) + audios.append(audio_batch[batch_index, 0, :audio_len].detach()) + return audios + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 15b3c322..6bee49be 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,8 +1,10 @@ +import asyncio import os import sys import threading import time from contextlib import contextmanager +from dataclasses import dataclass from tqdm import tqdm @@ -13,12 +15,13 @@ import re import torch from text.LangSegmenter import LangSegmenter from text import chinese -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload from tools.i18n.i18n import I18nAuto, scan_language_list @@ -92,6 +95,14 @@ class StageLimiter: } +@dataclass +class PreparedTextSegment: + language: str + phones: List[int] + word2ph: Optional[List[int]] + norm_text: str + + class TextPreprocessor: def __init__( self, @@ -149,7 +160,7 @@ class TextPreprocessor: # 解决输入目标文本的空行导致报错的问题 if len(text.strip()) == 0: continue - if not re.sub("\W+", "", text): + if not re.sub(r"\W+", "", text): # 检测一下,如果是纯符号,就跳过。 continue if text[-1] not in splits: @@ -168,7 +179,8 @@ class TextPreprocessor: def segment_and_extract_feature_for_text( self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version, profile=profile) + prepared_segments = self.preprocess_text_segments(text, language, version) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: textlist = [] @@ -223,24 +235,49 @@ class TextPreprocessor: def get_phones_and_bert( self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None ): - text = re.sub(r' {2,}', ' ', text) - textlist, langlist = self._split_text_by_language(text, language) - phones_list = [] - bert_list = [] - norm_text_list = [] - for segment_text, segment_lang in zip(textlist, langlist): - phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) - phones_list.append(phones) - norm_text_list.append(norm_text) + prepared_segments = self.preprocess_text_segments(text, language, version, final=final) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + def preprocess_text_segments( + self, + text: str, + language: str, + version: str, + final: bool = False, + ) -> List[PreparedTextSegment]: + payloads = preprocess_text_segments_payload(text, language, version, final=final) + return [ + PreparedTextSegment( + language=str(payload["language"]), + phones=list(payload["phones"]), + word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]), + norm_text=str(payload["norm_text"]), + ) + for payload in payloads + ] + + def build_phones_and_bert_from_segments( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor] = [] + norm_text_list: List[str] = [] + for segment in prepared_segments: + bert = self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) bert_list.append(bert) bert = torch.cat(bert_list, dim=1) phones = sum(phones_list, []) norm_text = "".join(norm_text_list) - - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile) - return phones, bert, norm_text def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: @@ -253,21 +290,41 @@ class TextPreprocessor: return profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None: + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0)) + self._accumulate_profile( + profile, + "bert_batch_collect_wait_ms", + worker_profile.get("bert_batch_collect_wait_ms", 0.0), + ) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + self._update_profile_peak( + profile, + "bert_pending_depth_on_enqueue_peak", + worker_profile.get("bert_pending_depth_on_enqueue", 0.0), + ) + self._update_profile_peak( + profile, + "bert_pending_depth_on_collect_peak", + worker_profile.get("bert_pending_depth_on_collect", 0.0), + ) + self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0)) + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: if self.bert_batch_worker is not None: feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) - self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) - self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) - self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) - self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) - self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) - self._update_profile_peak( - profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) - ) - self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) - self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) - if profile is not None: - profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + self._merge_bert_worker_profile(profile, worker_profile) return feature limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} @@ -310,9 +367,18 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): + def get_bert_inf( + self, + phones: list, + word2ph: Optional[list], + norm_text: str, + language: str, + profile: Dict | None = None, + ): language = language.replace("all_", "") if language == "zh": + if word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( @@ -322,6 +388,112 @@ class TextPreprocessor: return feature + async def build_phones_and_bert_from_segments_async( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + segment_jobs = self._build_async_segment_jobs(prepared_segments, profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(item_profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(segment_jobs) + + def _build_async_segment_jobs( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None, + ) -> Dict[str, List]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor | None] = [] + norm_text_list: List[str] = [] + + for segment in prepared_segments: + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) + segment_language = segment.language.replace("all_", "") + if segment_language == "zh" and self.bert_batch_worker is not None: + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + bert_list.append(None) + continue + bert_list.append( + self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + ) + return { + "phones_list": phones_list, + "bert_list": bert_list, + "norm_text_list": norm_text_list, + } + + @staticmethod + def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]: + bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1) + phones = sum(segment_jobs["phones_list"], []) + norm_text = "".join(segment_jobs["norm_text_list"]) + return phones, bert, norm_text + + async def build_phones_and_bert_pair_from_segments_async( + self, + prompt_segments: List[PreparedTextSegment], + target_segments: List[PreparedTextSegment], + prompt_profile: Dict | None = None, + target_profile: Dict | None = None, + ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]: + prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile) + target_jobs = self._build_async_segment_jobs(target_segments, target_profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + + for segment_jobs, prepared_segments, profile in ( + (prompt_jobs, prompt_segments, prompt_profile), + (target_jobs, target_segments, target_profile), + ): + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs) + def filter_text(self, texts): _text = [] if all(text in [None, " ", "\n", ""] for text in texts): diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py index b1ede3d8..1ac77faa 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -1,3 +1,4 @@ +import asyncio import threading import time import uuid @@ -14,7 +15,12 @@ class BertFeatureTask: word2ph: List[int] task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None result_feature: torch.Tensor | None = None error: Exception | None = None profile: Dict[str, float] = field(default_factory=dict) @@ -30,14 +36,37 @@ class PrepareBertBatchWorker: batch_window_ms: int = 5, max_batch_items: int = 16, max_batch_tokens: int = 4096, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + high_pressure_pending_threshold: int = 0, + high_pressure_batch_window_ms: int | None = None, + high_pressure_max_batch_items: int | None = None, + high_pressure_max_batch_tokens: int | None = None, ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device self.stage_limiter = stage_limiter - self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.batch_window_ms = max(0, int(batch_window_ms)) + self.batch_window_s = float(self.batch_window_ms) / 1000.0 self.max_batch_items = max(1, int(max_batch_items)) self.max_batch_tokens = max(16, int(max_batch_tokens)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + + self.high_pressure_pending_threshold = max( + 0, + int(high_pressure_pending_threshold) + if int(high_pressure_pending_threshold) > 0 + else max(self.max_batch_items * 2, 32), + ) + hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms) + hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items) + hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens) + self.high_pressure_batch_window_ms = max(0, hp_window_ms) + self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0 + self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items) + self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens) self.condition = threading.Condition() self.pending_tasks: Deque[BertFeatureTask] = deque() @@ -47,26 +76,70 @@ class PrepareBertBatchWorker: self.total_batches = 0 self.active_batch_size = 0 self.active_batch_peak = 0 + self.active_batch_tokens = 0 + self.active_batch_tokens_peak = 0 + self.high_pressure_batches = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) self.worker_thread.start() def _estimate_task_tokens(self, task: BertFeatureTask) -> int: return max(1, len(task.norm_text) + 2) + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks + + def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + def _enqueue_task(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + with self.condition: + while not self._can_enqueue_locked(): + self.condition.wait(timeout=self.admission_poll_s) + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + + async def _enqueue_task_async(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + while True: + with self.condition: + if self._can_enqueue_locked(): + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + return + await asyncio.sleep(self.admission_poll_s) + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) - with self.condition: - self.pending_tasks.append(task) - self.total_submitted += 1 - if len(self.pending_tasks) > self.pending_peak: - self.pending_peak = len(self.pending_tasks) - self.condition.notify_all() + self._enqueue_task(task) task.done_event.wait() if task.error is not None: raise task.error assert task.result_feature is not None return task.result_feature, dict(task.profile) + async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = BertFeatureTask( + norm_text=str(norm_text), + word2ph=list(word2ph), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + def snapshot(self) -> Dict[str, int]: with self.condition: return { @@ -77,21 +150,57 @@ class PrepareBertBatchWorker: "total_batches": self.total_batches, "active_batch_size": self.active_batch_size, "active_batch_peak": self.active_batch_peak, + "active_batch_tokens": self.active_batch_tokens, + "active_batch_tokens_peak": self.active_batch_tokens_peak, "batch_window_ms": int(self.batch_window_s * 1000.0), "max_batch_items": self.max_batch_items, "max_batch_tokens": self.max_batch_tokens, + "max_pending_tasks": self.max_pending_tasks, + "high_pressure_pending_threshold": self.high_pressure_pending_threshold, + "high_pressure_batch_window_ms": self.high_pressure_batch_window_ms, + "high_pressure_max_batch_items": self.high_pressure_max_batch_items, + "high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens, + "high_pressure_batches": self.high_pressure_batches, + "admission_wait_total_ms": self.admission_wait_total_ms, + "admission_wait_peak_ms": self.admission_wait_peak_ms, } - def _collect_batch(self) -> List[BertFeatureTask]: + def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]: + pending_depth = len(self.pending_tasks) + use_high_pressure = ( + self.high_pressure_pending_threshold > 0 + and pending_depth >= self.high_pressure_pending_threshold + ) + if use_high_pressure: + return ( + self.high_pressure_batch_window_s, + self.high_pressure_max_batch_items, + self.high_pressure_max_batch_tokens, + True, + pending_depth, + ) + return ( + self.batch_window_s, + self.max_batch_items, + self.max_batch_tokens, + False, + pending_depth, + ) + + def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]: with self.condition: while not self.pending_tasks: self.condition.wait() + collect_started = time.perf_counter() + batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = ( + self._select_batch_policy_locked() + ) batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] batch_tokens = self._estimate_task_tokens(batch[0]) - deadline = time.perf_counter() + self.batch_window_s + deadline = time.perf_counter() + batch_window_s - while len(batch) < self.max_batch_items: + while len(batch) < max_batch_items: remaining = deadline - time.perf_counter() if remaining <= 0: break @@ -100,26 +209,39 @@ class PrepareBertBatchWorker: continue next_task = self.pending_tasks[0] next_tokens = self._estimate_task_tokens(next_task) - if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens: break batch.append(self.pending_tasks.popleft()) batch_tokens += next_tokens self.active_batch_size = len(batch) + self.active_batch_tokens = batch_tokens if self.active_batch_size > self.active_batch_peak: self.active_batch_peak = self.active_batch_size - return batch + if self.active_batch_tokens > self.active_batch_tokens_peak: + self.active_batch_tokens_peak = self.active_batch_tokens + if use_high_pressure: + self.high_pressure_batches += 1 + return batch, { + "collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0, + "batch_tokens": float(batch_tokens), + "pending_depth_on_collect": float(pending_depth_on_collect), + "high_pressure_mode": 1.0 if use_high_pressure else 0.0, + "batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms), + } def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: with self.condition: self.active_batch_size = 0 + self.active_batch_tokens = 0 self.total_batches += 1 self.total_finished += len(batch) + self.condition.notify_all() - def _run_batch(self, batch: List[BertFeatureTask]) -> None: + def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None: batch_started = time.perf_counter() texts = [task.norm_text for task in batch] - batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + batch_tokens = int(batch_meta["batch_tokens"]) limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} if self.stage_limiter is None: @@ -167,6 +289,9 @@ class PrepareBertBatchWorker: task.result_feature = torch.cat(phone_level_feature, dim=0).T task.profile = { "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_admission_wait_ms": float(task.admission_wait_ms), + "bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0), + "bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]), "bert_forward_ms": float(forward_ms), "bert_tokenize_ms": float(tokenize_ms), "bert_scatter_ms": 0.0, @@ -175,6 +300,10 @@ class PrepareBertBatchWorker: "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), "bert_batch_size": float(len(batch)), "bert_batch_tokens": float(batch_tokens), + "bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]), + "bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]), + "bert_batch_window_ms": float(batch_meta["batch_window_ms"]), } except Exception as exc: # noqa: PERF203 task.error = exc @@ -183,15 +312,35 @@ class PrepareBertBatchWorker: if task.result_feature is not None: task.profile["bert_scatter_ms"] = float(scatter_ms) task.done_event.set() + self._notify_done_future(task) + + @staticmethod + def _resolve_done_future(task: BertFeatureTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_feature is not None + task.done_future.set_result((task.result_feature, dict(task.profile))) + + def _notify_done_future(self, task: BertFeatureTask) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass def _run_loop(self) -> None: while True: - batch = self._collect_batch() + batch, batch_meta = self._collect_batch() try: - self._run_batch(batch) + self._run_batch(batch, batch_meta) except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc task.done_event.set() + self._notify_done_future(task) finally: self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py new file mode 100644 index 00000000..1fdf95c5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import os +import threading +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + PreparedTextFeatures, + SchedulerRequestSpec, + T2SRequestState, + build_request_state_from_parts, + normalize_sentence, +) + + +@dataclass +class ProfiledResult: + result: Any + submit_at: float + started_at: float + finished_at: float + + @property + def queue_ms(self) -> float: + return max(0.0, (self.started_at - self.submit_at) * 1000.0) + + @property + def run_ms(self) -> float: + return max(0.0, (self.finished_at - self.started_at) * 1000.0) + + +class PrepareCoordinator: + def __init__(self, tts: Any): + self.tts = tts + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + self.use_async_text_feature_path = bool( + getattr(tts, "prepare_bert_batch_worker", None) is not None + and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" + ) + self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) + self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None + self.text_feature_workers = 0 + self.text_feature_executor = None + if not self.use_async_text_feature_path: + text_feature_default_workers = max(1, int(getattr(tts, "prepare_text_cpu_workers", 16) or 16)) + self.text_feature_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS", str(text_feature_default_workers))), + ) + self.text_feature_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.text_feature_workers, + thread_name_prefix="prepare-text-feature", + ) + ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) + self.ref_audio_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_REF_ASYNC_WORKERS", str(ref_audio_default_workers))), + ) + self.ref_audio_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.ref_audio_workers, + thread_name_prefix="prepare-ref-audio", + ) + + def _mark_enter(self) -> Tuple[int, int]: + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + return current_inflight, self.peak_inflight + + def _mark_leave(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "inflight": int(self.inflight), + "peak_inflight": int(self.peak_inflight), + "max_inflight": int(self.max_inflight), + "text_feature_workers": int(self.text_feature_workers), + "ref_audio_workers": int(self.ref_audio_workers), + } + + @staticmethod + def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: + started_at = time.perf_counter() + result = fn(*args) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + + def _prepare_text_cpu(self, text: str, language: str): + return self.tts.prepare_text_segments(text, language) + + def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures: + profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)} + branch_start = time.perf_counter() + phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile) + total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0) + profile["bert_total_ms"] = max(0.0, total_ms - float(cpu_run_ms)) + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=total_ms, + cpu_preprocess_ms=float(cpu_run_ms), + ) + + async def _run_on_executor(self, executor, fn, *args) -> ProfiledResult: + loop = asyncio.get_running_loop() + submit_at = time.perf_counter() + return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) + + async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + executor = getattr(self.tts, "prepare_text_cpu_executor", None) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + + async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: + return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms) + + @staticmethod + def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: + return float( + profile.get("bert_wait_ms", 0.0) + + profile.get("bert_tokenize_ms", 0.0) + + profile.get("bert_forward_ms", 0.0) + + profile.get("bert_scatter_ms", 0.0) + ) + + async def _run_text_feature_pair_stage( + self, + prompt_segments, + target_segments, + prompt_cpu_run_ms: float, + target_cpu_run_ms: float, + ) -> tuple[ProfiledResult, ProfiledResult]: + if self.text_feature_executor is not None: + prompt_feature_task = asyncio.create_task( + self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms) + ) + target_feature_task = asyncio.create_task( + self._run_text_feature_stage(target_segments, None, target_cpu_run_ms) + ) + return await asyncio.gather(prompt_feature_task, target_feature_task) + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} + submit_at = time.perf_counter() + started_at = float(submit_at) + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + finished_at = time.perf_counter() + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), + ) + target_result = PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: + return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + admission_start = time.perf_counter() + if self._inflight_semaphore is not None: + await self._inflight_semaphore.acquire() + prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0) + current_inflight, peak_inflight = self._mark_enter() + prepare_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + try: + text_pair_start = time.perf_counter() + prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang)) + target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang)) + ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path))) + prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task) + text_feature_pair_task = asyncio.create_task( + self._run_text_feature_pair_stage( + prompt_cpu_profiled.result, + target_cpu_profiled.result, + prompt_cpu_profiled.run_ms, + target_cpu_profiled.run_ms, + ) + ) + (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather( + text_feature_pair_task, + ref_audio_task, + ) + text_pair_end = time.perf_counter() + state = build_request_state_from_parts( + tts=self.tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_profiled.result, + prepare_start=prepare_start, + prepare_sync_start=prepare_start, + profile_overrides={ + "executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": prepare_admission_wait_ms, + "executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0), + "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + "prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms, + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": target_cpu_profiled.queue_ms, + "text_cpu_run_ms": target_cpu_profiled.run_ms, + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(current_inflight), + "worker_prepare_peak_inflight": float(peak_inflight), + }, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max( + 0.0, (prepare_exec_finished_at - prepare_start) * 1000.0 + ) + return state, prepare_start, prepare_exec_finished_at + finally: + self._mark_leave() + if self._inflight_semaphore is not None: + self._inflight_semaphore.release() diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index b7118a72..8aabd286 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,6 +1,5 @@ from __future__ import annotations -from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -89,20 +88,31 @@ class T2SFinishedItem: class T2SActiveBatch: request_ids: List[str] states: List[T2SRequestState] - x: torch.Tensor - x_lens: torch.LongTensor + x: Optional[torch.Tensor] + x_lens: Optional[torch.LongTensor] y_sequences: List[torch.LongTensor] prefix_lens: torch.LongTensor xy_pos: torch.Tensor - key_padding_mask: torch.Tensor - prefill_attn_mask: torch.Tensor + key_padding_mask: Optional[torch.Tensor] + prefill_attn_mask: Optional[torch.Tensor] decode_attn_mask: Optional[torch.Tensor] k_cache: Optional[List[torch.Tensor]] v_cache: Optional[List[torch.Tensor]] - step_idx: int + kv_lens: Optional[torch.LongTensor] + step_indices: torch.LongTensor prefill_done: bool +@dataclass +class PreparedTextFeatures: + phones: List[int] + bert_features: torch.Tensor + norm_text: str + profile: Dict[str, float] + total_ms: float + cpu_preprocess_ms: float + + def normalize_sentence(text: str, language: str) -> str: text = text.strip("\n").strip() if not text: @@ -113,105 +123,125 @@ def normalize_sentence(text: str, language: str) -> str: @torch.inference_mode() -def prepare_request_state( +def prepare_text_features( + tts: Any, + text: str, + language: str, +) -> PreparedTextFeatures: + device = tts.configs.device + profile: Dict[str, float] = {} + branch_start = time.perf_counter() + _sync_device(device) + cpu_start = time.perf_counter() + prepared_segments = tts.prepare_text_segments(text, language) + _sync_device(device) + cpu_preprocess_ms = (time.perf_counter() - cpu_start) * 1000.0 + profile["cpu_preprocess_ms"] = float(cpu_preprocess_ms) + bert_start = time.perf_counter() + phones, bert_features, norm_text = tts.build_text_features_from_segments(prepared_segments, profile=profile) + _sync_device(device) + profile["bert_total_ms"] = (time.perf_counter() - bert_start) * 1000.0 + total_ms = (time.perf_counter() - branch_start) * 1000.0 + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=float(total_ms), + cpu_preprocess_ms=float(cpu_preprocess_ms), + ) + + +@torch.inference_mode() +def build_request_state_from_parts( tts: Any, spec: SchedulerRequestSpec, + prompt_text: str, + text: str, + prompt_result: PreparedTextFeatures, + target_result: PreparedTextFeatures, + ref_audio_bundle: Dict[str, Any], + prepare_start: float, + prepare_sync_start: float, + profile_overrides: Optional[Dict[str, float]] = None, ) -> T2SRequestState: device = tts.configs.device - prepare_start = time.perf_counter() _sync_device(device) - prepare_sync_start = time.perf_counter() - prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) - text = spec.text.strip("\n") - - prompt_text_profile: Dict[str, float] = {} - text_features_profile: Dict[str, float] = {} - text_feature_pair_start = time.perf_counter() - prompt_future: Future | None = None - - def _extract_prompt_features(): - _sync_device(device) - prompt_start = time.perf_counter() - result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) - _sync_device(device) - return result, (time.perf_counter() - prompt_start) * 1000.0 - - if getattr(tts, "prepare_text_cpu_executor", None) is not None: - prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) - - _sync_device(device) - text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) - _sync_device(device) - text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 - - if prompt_future is None: - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features( - prompt_text, spec.prompt_lang, profile=prompt_text_profile - ) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 - prompt_text_profile["parallel_future_wait_ms"] = 0.0 - else: - prompt_wait_start = time.perf_counter() - (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() - prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 - - text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 - if phones is None: - raise ValueError(f"{spec.request_id} text preprocessing returned no phones") - - _sync_device(device) - ref_audio_bundle_start = time.perf_counter() - ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) + bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() spec_audio, audio_16k = ref_audio_bundle["refer_spec"] raw_audio = ref_audio_bundle["raw_audio"] raw_sr = int(ref_audio_bundle["raw_sr"]) - _sync_device(device) - ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 - bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() - phones_tensor = torch.LongTensor(phones).to(tts.configs.device) - prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) - all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) - all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to( dtype=tts.precision, device=tts.configs.device ) _sync_device(device) tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 - _sync_device(device) prepare_profile = { - "prompt_text_features_ms": prompt_text_features_ms, - "text_features_ms": text_features_ms, - "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), - "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), - "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), - "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), - "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), - "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), - "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), - "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), - "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), - "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), - "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), - "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), - "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), - "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), - "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), - "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), - "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), - "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), - "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), - "text_feature_pair_ms": text_feature_pair_ms, + "prompt_text_features_ms": float(prompt_result.total_ms), + "text_features_ms": float(target_result.total_ms), + "prompt_text_cpu_preprocess_ms": float(prompt_result.cpu_preprocess_ms), + "text_cpu_preprocess_ms": float(target_result.cpu_preprocess_ms), + "prompt_text_bert_wait_ms": float(prompt_result.profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_admission_wait_ms": float(prompt_result.profile.get("bert_admission_wait_ms", 0.0)), + "prompt_text_bert_queue_wait_ms": float(prompt_result.profile.get("bert_queue_wait_ms", 0.0)), + "prompt_text_bert_batch_collect_wait_ms": float(prompt_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_result.profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_result.profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_result.profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_result.profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_result.profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_result.profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_result.profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_result.profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_bert_pending_depth_on_enqueue_peak": float( + prompt_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "prompt_text_bert_pending_depth_on_collect_peak": float( + prompt_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "prompt_text_bert_high_pressure_mode_peak": float( + prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0) + ), + "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_finish_after_submit_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "text_bert_wait_ms": float(target_result.profile.get("bert_wait_ms", 0.0)), + "text_bert_admission_wait_ms": float(target_result.profile.get("bert_admission_wait_ms", 0.0)), + "text_bert_queue_wait_ms": float(target_result.profile.get("bert_queue_wait_ms", 0.0)), + "text_bert_batch_collect_wait_ms": float(target_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "text_bert_forward_ms": float(target_result.profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(target_result.profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(target_result.profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(target_result.profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(target_result.profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(target_result.profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(target_result.profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(target_result.profile.get("bert_batch_tokens_peak", 0.0)), + "text_bert_pending_depth_on_enqueue_peak": float( + target_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "text_bert_pending_depth_on_collect_peak": float( + target_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)), + "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)), + "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)), "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), "audio_load_ms": audio_load_ms, "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), @@ -233,6 +263,8 @@ def prepare_request_state( "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, } + if profile_overrides: + prepare_profile.update({key: float(value) for key, value in profile_overrides.items()}) return T2SRequestState( request_id=spec.request_id, ref_audio_path=spec.ref_audio_path, @@ -240,8 +272,8 @@ def prepare_request_state( prompt_lang=spec.prompt_lang, text=text, text_lang=spec.text_lang, - norm_prompt_text=prompt_norm_text, - norm_text=norm_text, + norm_prompt_text=prompt_result.norm_text, + norm_text=target_result.norm_text, phones=phones_tensor, prompt_phones=prompt_phones_tensor, all_phones=all_phones, @@ -260,6 +292,33 @@ def prepare_request_state( ) +@torch.inference_mode() +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + prepare_start = time.perf_counter() + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang) + target_result = prepare_text_features(tts, text, spec.text_lang) + if target_result.phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + return build_request_state_from_parts( + tts=tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_result, + target_result=target_result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=prepare_start, + prepare_sync_start=prepare_sync_start, + ) + + def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: if hidden.shape[0] >= target_len: return hidden @@ -417,7 +476,8 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct decode_attn_mask=None, k_cache=None, v_cache=None, - step_idx=0, + kv_lens=None, + step_indices=torch.zeros((len(states),), dtype=torch.long, device=device), prefill_done=False, ) @@ -433,6 +493,64 @@ def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> to ) +def _compact_cache_to_kv_lens( + cache: torch.Tensor, + kv_lens: torch.LongTensor, +) -> torch.Tensor: + target_len = int(kv_lens.max().item()) + if cache.shape[1] == target_len and torch.all(kv_lens == target_len).item(): + return cache + compacted = cache.new_zeros((cache.shape[0], target_len, cache.shape[2])) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + if kv_len <= 0: + continue + compacted[batch_index, -kv_len:, :] = cache[batch_index, -kv_len:, :] + return compacted + + +def _compact_decode_mask_to_kv_lens( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + target_len = int(kv_lens.max().item()) + 1 + if decode_attn_mask is None: + return None + if decode_attn_mask.shape[-1] == target_len and torch.all(kv_lens + 1 == target_len).item(): + return decode_attn_mask + compacted = torch.ones( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + compacted[batch_index, :, :, -current_len:] = decode_attn_mask[batch_index, :, :, -current_len:] + if not compacted.any().item(): + return None + return compacted + + +def _advance_decode_mask( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + if decode_attn_mask is None: + return None + target_len = int(kv_lens.max().item()) + 2 + advanced = torch.zeros( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + next_mask = F.pad(decode_attn_mask[batch_index : batch_index + 1, :, :, -current_len:], (0, 1), value=False) + advanced[batch_index : batch_index + 1, :, :, -next_mask.shape[-1] :] = next_mask + if not advanced.any().item(): + return None + return advanced + + def _sample_per_request( model: Any, active_batch: T2SActiveBatch, @@ -443,16 +561,15 @@ def _sample_per_request( keep_indices: List[int] = [] updated_sequences: List[torch.LongTensor] = [] - step_idx = active_batch.step_idx sampling_keys = [ _sampling_group_key( top_k=state.top_k, top_p=state.top_p, temperature=state.temperature, repetition_penalty=state.repetition_penalty, - trim_eos=False, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, ) - for state in active_batch.states + for batch_index, state in enumerate(active_batch.states) ] sampled_items, argmax_tokens = _batched_sample_by_group( logits=logits, @@ -460,6 +577,7 @@ def _sample_per_request( sampling_keys=sampling_keys, ) for batch_index, state in enumerate(active_batch.states): + step_index = int(active_batch.step_indices[batch_index].item()) current_history = active_batch.y_sequences[batch_index] sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) @@ -469,7 +587,7 @@ def _sample_per_request( finish_reason: Optional[str] = None if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: finish_reason = "early_stop" - elif step_idx + 1 >= max_steps: + elif step_index + 1 >= max_steps: finish_reason = "max_step" elif sampled_token == model.EOS: finish_reason = "eos_sample" @@ -482,7 +600,7 @@ def _sample_per_request( T2SFinishedItem( request_id=state.request_id, semantic_tokens=new_history[prefix_len:-1].clone(), - finish_idx=step_idx, + finish_idx=step_index, finish_reason=finish_reason, ) ) @@ -493,30 +611,48 @@ def _sample_per_request( return finished_items, keep_indices, updated_sequences +@torch.inference_mode() def decode_one_step( model: Any, active_batch: T2SActiveBatch, max_steps: int, ) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: - if not active_batch.prefill_done: + was_prefill = not active_batch.prefill_done + if was_prefill: + if active_batch.prefill_attn_mask is None or active_batch.key_padding_mask is None: + raise ValueError("prefill 阶段缺少必要 mask") xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( active_batch.xy_pos, active_batch.prefill_attn_mask, None ) + active_batch.kv_lens = active_batch.x_lens + active_batch.prefix_lens active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("prefill 阶段未生成完整 KV cache") + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(active_batch.decode_attn_mask, active_batch.kv_lens) + active_batch.x = None + active_batch.x_lens = None + active_batch.key_padding_mask = None + active_batch.prefill_attn_mask = None active_batch.prefill_done = True else: + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("decode 阶段缺少 KV cache") + batched_decode_attn_mask = None + if active_batch.decode_attn_mask is not None: + batched_decode_attn_mask = _materialize_decode_mask_for_active_batch(active_batch) + if not batched_decode_attn_mask.any().item(): + batched_decode_attn_mask = None xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( active_batch.xy_pos, active_batch.k_cache, active_batch.v_cache, - active_batch.decode_attn_mask, + batched_decode_attn_mask, ) - if active_batch.decode_attn_mask is not None: - active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + active_batch.decode_attn_mask = _advance_decode_mask(active_batch.decode_attn_mask, active_batch.kv_lens) logits = model.ar_predict_layer(xy_dec[:, -1]) - if active_batch.step_idx < 11: - logits = logits[:, :-1] finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) if len(keep_indices) == 0: @@ -528,16 +664,32 @@ def decode_one_step( active_batch.states = [active_batch.states[i] for i in keep_indices] active_batch.y_sequences = updated_sequences active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + next_step_indices = torch.index_select(active_batch.step_indices, dim=0, index=keep_tensor) + next_kv_lens = None if active_batch.kv_lens is None else torch.index_select(active_batch.kv_lens, dim=0, index=keep_tensor) + active_batch.step_indices = next_step_indices + 1 + if not was_prefill: + if next_kv_lens is not None: + active_batch.kv_lens = next_kv_lens + 1 + else: + active_batch.kv_lens = next_kv_lens if active_batch.decode_attn_mask is not None: active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if not active_batch.decode_attn_mask.any().item(): + active_batch.decode_attn_mask = None if active_batch.k_cache is not None and active_batch.v_cache is not None: for cache_index in range(len(active_batch.k_cache)): active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + if active_batch.kv_lens is not None: + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens( + active_batch.decode_attn_mask, + active_batch.kv_lens, + ) active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) - active_batch.step_idx += 1 return active_batch, finished_items @@ -583,6 +735,126 @@ def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> ) +def _materialize_decode_mask_for_active_batch( + active_batch: T2SActiveBatch, + target_mask_len: Optional[int] = None, +) -> torch.Tensor: + if active_batch.k_cache is None or active_batch.kv_lens is None: + raise ValueError("active batch 缺少 KV cache 或 kv_lens") + current_mask_len = active_batch.k_cache[0].shape[1] + 1 + if target_mask_len is None: + target_mask_len = current_mask_len + if active_batch.decode_attn_mask is None: + mask = torch.zeros( + (len(active_batch.request_ids), 1, 1, current_mask_len), + dtype=torch.bool, + device=active_batch.k_cache[0].device, + ) + else: + rows: List[torch.Tensor] = [] + for batch_index, kv_len in enumerate(active_batch.kv_lens.tolist()): + row_len = kv_len + 1 + row_mask = _fit_decode_mask_length( + active_batch.decode_attn_mask[batch_index : batch_index + 1], + row_len, + ) + rows.append(_pad_decode_mask_left(row_mask, target_mask_len)) + mask = torch.cat(rows, dim=0) + if target_mask_len != current_mask_len and active_batch.decode_attn_mask is None: + mask = _pad_decode_mask_left(mask, target_mask_len) + return mask + + +@torch.inference_mode() +def run_prefill_active_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not states: + return None, [] + active_batch = build_prefill_batch(model, states) + return decode_one_step(model, active_batch, max_steps=max_steps) + + +@torch.inference_mode() +def merge_active_batches( + model: Any, + left_batch: Optional[T2SActiveBatch], + right_batch: Optional[T2SActiveBatch], +) -> Optional[T2SActiveBatch]: + if left_batch is None: + return right_batch + if right_batch is None: + return left_batch + if not left_batch.prefill_done or not right_batch.prefill_done: + raise ValueError("只有 prefill 完成后的 active batch 才能 merge") + if left_batch.k_cache is None or left_batch.v_cache is None or right_batch.k_cache is None or right_batch.v_cache is None: + raise ValueError("merge active batch 时缺少 KV cache") + + left_kv_len = int(left_batch.k_cache[0].shape[1]) + right_kv_len = int(right_batch.k_cache[0].shape[1]) + merged_kv_len = max(left_kv_len, right_kv_len) + merged_mask_len = merged_kv_len + 1 + + merged_k_cache: List[torch.Tensor] = [] + merged_v_cache: List[torch.Tensor] = [] + for layer_index in range(len(left_batch.k_cache)): + merged_k_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.k_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.k_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + merged_v_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.v_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.v_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + + merged_decode_attn_mask = torch.cat( + [ + _materialize_decode_mask_for_active_batch(left_batch, merged_mask_len), + _materialize_decode_mask_for_active_batch(right_batch, merged_mask_len), + ], + dim=0, + ) + merged_request_ids = list(left_batch.request_ids) + list(right_batch.request_ids) + merged_states = list(left_batch.states) + list(right_batch.states) + merged_y_sequences = list(left_batch.y_sequences) + list(right_batch.y_sequences) + merged_prefix_lens = torch.cat([left_batch.prefix_lens, right_batch.prefix_lens], dim=0) + if left_batch.kv_lens is None or right_batch.kv_lens is None: + raise ValueError("merge active batch 时缺少 kv_lens") + merged_kv_lens = torch.cat([left_batch.kv_lens, right_batch.kv_lens], dim=0) + merged_decode_attn_mask = _compact_decode_mask_to_kv_lens(merged_decode_attn_mask, merged_kv_lens) + merged_step_indices = torch.cat([left_batch.step_indices, right_batch.step_indices], dim=0) + + return T2SActiveBatch( + request_ids=merged_request_ids, + states=merged_states, + x=None, + x_lens=None, + y_sequences=merged_y_sequences, + prefix_lens=merged_prefix_lens, + xy_pos=build_next_xy_pos(model, merged_y_sequences), + key_padding_mask=None, + prefill_attn_mask=None, + decode_attn_mask=merged_decode_attn_mask, + k_cache=merged_k_cache, + v_cache=merged_v_cache, + kv_lens=merged_kv_lens, + step_indices=merged_step_indices, + prefill_done=True, + ) + + @torch.inference_mode() def run_prefill_step( model: Any, @@ -804,29 +1076,24 @@ def run_scheduler_continuous( max_steps: int, ) -> List[T2SFinishedItem]: pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) - running_requests: List[T2SRunningRequest] = [] + active_batch: Optional[T2SActiveBatch] = None finished: List[T2SFinishedItem] = [] current_tick = 0 - while pending or running_requests: + while pending or active_batch is not None: admitted: List[T2SRequestState] = [] while pending and pending[0].ready_step <= current_tick: admitted.append(pending.pop(0)) - admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + admitted_active_batch, admitted_finished = run_prefill_active_batch(model, admitted, max_steps=max_steps) finished.extend(admitted_finished) + active_batch = merge_active_batches(model, active_batch, admitted_active_batch) - if running_requests: - running_requests, step_finished = run_decode_step_for_running( - model, - running_requests, - max_steps=max_steps, - ) + if active_batch is not None: + active_batch, step_finished = decode_one_step(model, active_batch, max_steps=max_steps) finished.extend(step_finished) - running_requests.extend(admitted_running) - - if not running_requests and pending: + if active_batch is None and pending: current_tick = max(current_tick + 1, pending[0].ready_step) continue diff --git a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py new file mode 100644 index 00000000..e2398251 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py @@ -0,0 +1,100 @@ +import os +import re +import sys +from typing import Dict, List, Optional, Tuple + +now_dir = os.getcwd() +sys.path.append(now_dir) + +from text.LangSegmenter import LangSegmenter +from text import cleaned_text_to_sequence +from text.cleaner import clean_text + + +PreparedTextSegmentPayload = Dict[str, object] + + +def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]: + textlist: List[str] = [] + langlist: List[str] = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist + + +def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]: + normalized_language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, normalized_language, version) + phones = cleaned_text_to_sequence(phones, version) + return list(phones), None if word2ph is None else list(word2ph), str(norm_text) + + +def preprocess_text_segments_payload( + text: str, + language: str, + version: str, + final: bool = False, +) -> List[PreparedTextSegmentPayload]: + text = re.sub(r" {2,}", " ", text) + textlist, langlist = split_text_by_language(text, language) + payloads: List[PreparedTextSegmentPayload] = [] + total_phones_len = 0 + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + payloads.append( + { + "language": segment_lang.replace("all_", ""), + "phones": phones, + "word2ph": word2ph, + "norm_text": norm_text, + } + ) + total_phones_len += len(phones) + + if not final and total_phones_len < 6: + return preprocess_text_segments_payload("." + text, language, version, final=True) + + return payloads diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 348ddb3f..c6d147cf 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -2,6 +2,7 @@ import warnings warnings.filterwarnings("ignore") import math +from typing import List import torch from torch import nn @@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + @torch.no_grad() + def decode_batched_request_local( + self, + codes: torch.Tensor, + code_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + refer_list: List[torch.Tensor], + noise_scale: float = 0.5, + speed: float = 1, + sv_emb: torch.Tensor | None = None, + ): + batch_size = int(codes.size(1)) + if batch_size <= 0: + raise ValueError("decode_batched_request_local 收到空 batch") + if len(refer_list) != batch_size: + raise ValueError("refer_list 数量与 batch size 不一致") + + refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device) + max_refer_len = int(refer_lengths.max().item()) + refer_batch = torch.zeros( + (batch_size, int(refer_list[0].size(1)), max_refer_len), + dtype=refer_list[0].dtype, + device=codes.device, + ) + for batch_index, refer in enumerate(refer_list): + refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype) + if self.version == "v1": + ge = self.ref_enc(refer_batch * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + if sv_emb is None: + raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb") + ge = ge + self.sv_emb(sv_emb).unsqueeze(-1) + ge = self.prelu(ge) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") + y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2 + text_lengths = text_lengths.to(device=text.device, dtype=torch.long) + x, m_p, logs_p, y_mask, _, _ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=ge, reverse=True) + audio = self.dec((z * y_mask)[:, :, :], g=ge) + upsample_factor = 1 + for up_layer in self.dec.ups: + stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride) + upsample_factor *= int(stride) + audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor) + return audio, audio_lengths + @torch.no_grad() def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): diff --git a/api_v3.py b/api_v3.py index 92f9a3b9..74bc7ac8 100644 --- a/api_v3.py +++ b/api_v3.py @@ -107,6 +107,7 @@ import sys import time import traceback import uuid +from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Generator, List, Union @@ -115,6 +116,10 @@ now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import argparse import subprocess import wave @@ -128,14 +133,15 @@ import uvicorn from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( SchedulerRequestSpec, + T2SActiveBatch, T2SFinishedItem, - T2SRunningRequest, T2SRequestState, - prepare_request_state, - run_decode_step_for_running, - run_prefill_step, + merge_active_batches, + decode_one_step, + run_prefill_active_batch, run_scheduler_continuous, ) from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names @@ -238,39 +244,71 @@ class SchedulerPendingJob: request_id: str state: T2SRequestState done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None enqueue_time: float speed_factor: float sample_steps: int media_type: str - prepare_ms: float = 0.0 prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 first_schedule_time: float | None = None prefill_ms: float = 0.0 + merge_ms: float = 0.0 decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 synth_ms: float = 0.0 pack_ms: float = 0.0 decode_steps: int = 0 + result_ready_time: float | None = None result: dict | None = None sample_rate: int | None = None audio_data: np.ndarray | None = None error: str | None = None +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + class SchedulerDebugWorker: def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): self.tts = tts self.max_steps = max_steps self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_coordinator = PrepareCoordinator(tts) self.condition = threading.Condition() self.prepare_inflight = 0 self.prepare_peak_inflight = 0 + self.finalize_condition = threading.Condition() + self.finalize_pending_tasks: deque[SchedulerFinalizeTask] = deque() + self.finalize_pending_peak = 0 + self.finalize_inflight = 0 + self.finalize_inflight_peak = 0 + self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) self.pending_jobs: List[SchedulerPendingJob] = [] - self.running_requests: List[T2SRunningRequest] = [] + self.active_batch: T2SActiveBatch | None = None self.job_map: dict[str, SchedulerPendingJob] = {} self.total_finished = 0 self.total_submitted = 0 self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) self.worker_thread.start() + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"t2s-scheduler-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_workers) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() def _sync_device(self) -> None: try: @@ -283,20 +321,7 @@ class SchedulerDebugWorker: pass def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - with self.condition: - self.prepare_inflight += 1 - prepare_inflight_on_enter = self.prepare_inflight - if self.prepare_inflight > self.prepare_peak_inflight: - self.prepare_peak_inflight = self.prepare_inflight - prepare_peak_inflight = self.prepare_peak_inflight - try: - state = prepare_request_state(self.tts, spec) - state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter) - state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight) - return state - finally: - with self.condition: - self.prepare_inflight = max(0, self.prepare_inflight - 1) + raise RuntimeError("prepare_state sync path has been replaced by PrepareCoordinator") def submit( self, @@ -304,27 +329,47 @@ class SchedulerDebugWorker: speed_factor: float, sample_steps: int, media_type: str, - prepare_ms: float, prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, ) -> SchedulerPendingJob: job = SchedulerPendingJob( request_id=state.request_id, state=state, done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, enqueue_time=time.perf_counter(), speed_factor=float(speed_factor), sample_steps=int(sample_steps), media_type=media_type, - prepare_ms=float(prepare_ms), prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), ) with self.condition: self.pending_jobs.append(job) self.job_map[job.request_id] = job self.total_submitted += 1 self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() return job + async def prepare_state_async(self, spec: SchedulerRequestSpec) -> T2SRequestState: + state, _, _ = await self.prepare_coordinator.prepare_state_profiled_async(spec, time.perf_counter()) + return state + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await asyncio.gather(*[self.prepare_state_async(spec) for spec in specs]) + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: with self.condition: for job in jobs: @@ -340,6 +385,14 @@ class SchedulerDebugWorker: if tracked_job is not None: tracked_job.prefill_ms += elapsed_ms + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.merge_ms += elapsed_ms + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: elapsed_ms = elapsed_s * 1000.0 with self.condition: @@ -349,16 +402,30 @@ class SchedulerDebugWorker: job.decode_ms += elapsed_ms job.decode_steps += 1 + def _add_finalize_wait_ms(self, request_ids: List[str], elapsed_ms: float) -> None: + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.finalize_wait_ms += elapsed_ms + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) - phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + semantic_tokens = item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.detach().clone().unsqueeze(0).to(self.tts.configs.device) + prompt_semantic = job.state.prompt_semantic.detach().clone() + prompt_phones = job.state.prompt_phones.detach().clone() + refer_spec = ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + raw_audio = job.state.raw_audio.detach().clone() audio_fragment = self.tts.synthesize_audio_request_local( semantic_tokens=semantic_tokens, phones=phones, - prompt_semantic=job.state.prompt_semantic, - prompt_phones=job.state.prompt_phones, - refer_spec=job.state.refer_spec, - raw_audio=job.state.raw_audio, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_spec=refer_spec, + raw_audio=raw_audio, raw_sr=job.state.raw_sr, speed=float(job.speed_factor), sample_steps=int(job.sample_steps), @@ -375,6 +442,11 @@ class SchedulerDebugWorker: ) def get_state(self) -> dict: + with self.finalize_condition: + finalize_pending = len(self.finalize_pending_tasks) + finalize_pending_peak = self.finalize_pending_peak + finalize_inflight = self.finalize_inflight + finalize_inflight_peak = self.finalize_inflight_peak with self.condition: bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() @@ -388,12 +460,24 @@ class SchedulerDebugWorker: if self.tts.prepare_ref_semantic_batch_worker is None else self.tts.prepare_ref_semantic_batch_worker.snapshot() ) + prepare_coordinator_state = self.prepare_coordinator.snapshot() return { "pending_jobs": len(self.pending_jobs), - "running_requests": len(self.running_requests), - "prepare_inflight": self.prepare_inflight, - "prepare_peak_inflight": self.prepare_peak_inflight, + "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), + "prepare_inflight": prepare_coordinator_state["inflight"], + "prepare_peak_inflight": prepare_coordinator_state["peak_inflight"], + "finalize_pending": finalize_pending, + "finalize_pending_peak": finalize_pending_peak, + "finalize_inflight": finalize_inflight, + "finalize_inflight_peak": finalize_inflight_peak, + "finalize_workers": self.finalize_workers, + "finalize_mode": self.finalize_mode, + "finalize_batch_max_items": self.finalize_batch_max_items, + "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, + "prepare_request_executor_workers": 0, "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_text_feature_workers": int(prepare_coordinator_state["text_feature_workers"]), + "prepare_ref_audio_workers": int(prepare_coordinator_state["ref_audio_workers"]), "prepare_bert_stage": bert_stage, "prepare_bert_batch_worker": bert_batch_worker, "prepare_ref_audio_stage": ref_audio_stage, @@ -405,59 +489,217 @@ class SchedulerDebugWorker: "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), } - def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: if not items: return - jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + tasks: List[SchedulerFinalizeTask] = [] + enqueued_time = time.perf_counter() with self.condition: for item in items: job = self.job_map.get(item.request_id) if job is not None: - jobs_to_finalize.append((job, item)) + tasks.append( + SchedulerFinalizeTask( + request_id=item.request_id, + item=item, + enqueued_time=enqueued_time, + ) + ) + if not tasks: + return + with self.finalize_condition: + self.finalize_pending_tasks.extend(tasks) + if len(self.finalize_pending_tasks) > self.finalize_pending_peak: + self.finalize_pending_peak = len(self.finalize_pending_tasks) + self.finalize_condition.notify_all() - for job, item in jobs_to_finalize: + @staticmethod + def _finalize_batch_key(job: SchedulerPendingJob) -> tuple[float, int]: + return (round(float(job.speed_factor), 6), int(job.sample_steps)) + + def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: + with self.finalize_condition: + while not self.finalize_pending_tasks: + self.finalize_condition.wait() + if self.finalize_mode == "after_t2s_drain": + while not self._is_t2s_drained(): + self.finalize_condition.wait(timeout=self.micro_batch_wait_s) + task = self.finalize_pending_tasks.popleft() + selected_tasks = [task] + batch_key = None + with self.condition: + first_job = self.job_map.get(task.request_id) + if first_job is not None: + batch_key = self._finalize_batch_key(first_job) + batch_deadline = time.perf_counter() + self.finalize_batch_wait_s + while len(selected_tasks) < self.finalize_batch_max_items: + if batch_key is None: + break + matched_index = None + for pending_index, pending_task in enumerate(self.finalize_pending_tasks): + with self.condition: + pending_job = self.job_map.get(pending_task.request_id) + if pending_job is None: + matched_index = pending_index + break + if self._finalize_batch_key(pending_job) == batch_key: + matched_index = pending_index + break + if matched_index is not None: + selected_tasks.append(self.finalize_pending_tasks[matched_index]) + del self.finalize_pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + self.finalize_inflight += len(selected_tasks) + if self.finalize_inflight > self.finalize_inflight_peak: + self.finalize_inflight_peak = self.finalize_inflight + return selected_tasks + + def _finalize_task_done(self, count: int) -> None: + with self.finalize_condition: + self.finalize_inflight = max(0, self.finalize_inflight - count) + + def _is_t2s_drained(self) -> bool: + with self.condition: + return ( + self.active_batch is None + and not self.pending_jobs + and self.prepare_inflight <= 0 + ) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + return + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def _run_finalize_loop(self) -> None: + while True: + tasks = self._take_finalize_task_batch() try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + finalize_wait_request_ids: List[str] = [] + with self.condition: + for task in tasks: + job = self.job_map.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + finalize_wait_request_ids.append(task.request_id) + if not jobs_and_items: + continue + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) self._sync_device() synth_start = time.perf_counter() - sample_rate, audio_data = self._synthesize_finished_audio(job, item) + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) self._sync_device() synth_ms = (time.perf_counter() - synth_start) * 1000.0 + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) except Exception as exc: - self._finalize_error([item.request_id], str(exc)) - continue - - finished_at = time.perf_counter() - with self.condition: - if self.job_map.get(item.request_id) is not job: - continue - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - job.synth_ms += synth_ms - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - prepare_profile = dict(job.state.prepare_profile) - job.result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "prepare_ms": job.prepare_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "decode_ms": job.decode_ms, - "synth_ms": job.synth_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.done_event.set() - self.job_map.pop(item.request_id, None) - self.total_finished += 1 + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self._finalize_task_done(len(tasks)) def _finalize_error(self, request_ids: List[str], error: str) -> None: if not request_ids: @@ -469,12 +711,28 @@ class SchedulerDebugWorker: continue job.error = error job.done_event.set() + self._notify_done_future(job) self.job_map.pop(request_id, None) self.total_finished += 1 + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(True) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: with self.condition: - if not self.pending_jobs and not self.running_requests: + if not self.pending_jobs and self.active_batch is None: self.condition.wait(timeout=self.micro_batch_wait_s) elif wait_for_batch and self.pending_jobs: self.condition.wait(timeout=self.micro_batch_wait_s) @@ -482,11 +740,13 @@ class SchedulerDebugWorker: return [] pending = list(self.pending_jobs) self.pending_jobs.clear() + with self.finalize_condition: + self.finalize_condition.notify_all() return pending def _run_loop(self) -> None: while True: - wait_for_batch = len(self.running_requests) == 0 + wait_for_batch = self.active_batch is None pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) if pending_jobs: @@ -494,37 +754,54 @@ class SchedulerDebugWorker: self._sync_device() prefill_start = time.perf_counter() self._mark_prefill_started(pending_jobs, prefill_start) - admitted_running, admitted_finished = run_prefill_step( + admitted_active_batch, admitted_finished = run_prefill_active_batch( self.tts.t2s_model.model, [job.state for job in pending_jobs], max_steps=self.max_steps, ) self._sync_device() self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) - self._finalize_finished(admitted_finished) - self.running_requests.extend(admitted_running) + self._enqueue_finalize_finished(admitted_finished) + merge_start = time.perf_counter() + self.active_batch = merge_active_batches( + self.tts.t2s_model.model, + self.active_batch, + admitted_active_batch, + ) + self._add_merge_time( + [] if self.active_batch is None else list(self.active_batch.request_ids), + time.perf_counter() - merge_start, + ) + with self.finalize_condition: + self.finalize_condition.notify_all() except Exception as exc: self._finalize_error([job.request_id for job in pending_jobs], str(exc)) - if self.running_requests: + if self.active_batch is not None: try: - active_request_ids = [item.state.request_id for item in self.running_requests] + active_request_ids = [state.request_id for state in self.active_batch.states] self._sync_device() decode_start = time.perf_counter() - self.running_requests, step_finished = run_decode_step_for_running( + self.active_batch, step_finished = decode_one_step( self.tts.t2s_model.model, - self.running_requests, + self.active_batch, max_steps=self.max_steps, ) self._sync_device() self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) - self._finalize_finished(step_finished) + self._enqueue_finalize_finished(step_finished) + with self.finalize_condition: + self.finalize_condition.notify_all() except Exception as exc: self._finalize_error(active_request_ids, str(exc)) - self.running_requests = [] + self.active_batch = None + with self.finalize_condition: + self.finalize_condition.notify_all() continue if not pending_jobs: + with self.finalize_condition: + self.finalize_condition.notify_all() time.sleep(self.micro_batch_wait_s) @@ -788,10 +1065,6 @@ def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: ] -def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return [scheduler_debug_worker.prepare_state(spec) for spec in specs] - - def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: payload = request.dict() request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" @@ -845,7 +1118,7 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): try: set_scheduler_seed(request.seed) specs = build_scheduler_request_specs(request.requests) - states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + states = await scheduler_debug_worker.prepare_states_batch_async(specs) finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) return JSONResponse( status_code=200, @@ -867,20 +1140,51 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): try: request_start = time.perf_counter() + prepare_start = request_start spec = build_scheduler_submit_spec(request) - prepare_start = time.perf_counter() - state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) - prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 - prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async( + spec, + spec_ready_at, + ) + prepare_end = time.perf_counter() + prepare_wall_ms = (prepare_end - prepare_start) * 1000.0 + prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms)) + prepare_executor_queue_ms = float( + state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)) + ) + prepare_executor_run_ms = float( + state.prepare_profile.get( + "executor_run_wall_ms", + max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0), + ) + ) + prepare_other_ms = max( + 0.0, + prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() job = scheduler_debug_worker.submit( state, speed_factor=float(request.speed_factor), sample_steps=int(request.sample_steps), media_type=request.media_type, - prepare_ms=prepare_ms, prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, ) - timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0) + timeout_ok = False + try: + await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec)) + timeout_ok = True + except asyncio.TimeoutError: + timeout_ok = False + wait_return_at = time.perf_counter() if not timeout_ok: return JSONResponse( status_code=202, @@ -888,8 +1192,10 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "message": "queued", "request_id": job.request_id, "timings": { - "prepare_ms": prepare_ms, + "prepare_ms": prepare_wall_ms, "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "api_after_prepare_ms": api_after_prepare_ms, "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), }, "worker_state": scheduler_debug_worker.get_state(), @@ -911,9 +1217,13 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): ) pack_start = time.perf_counter() audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_ms = (time.perf_counter() - pack_start) * 1000.0 + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 job.pack_ms = pack_ms - request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 headers = { "X-Request-Id": job.request_id, "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", @@ -921,16 +1231,32 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "X-Queue-Wait-Ms": ( f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" ), - "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", + "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", + "X-Prepare-Admission-Wait-Ms": ( + f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" + if job.result is not None + else "0.000" + ), + "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", + "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", + "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", + "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", + "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", "X-Pack-Ms": f"{pack_ms:.3f}", "X-Worker-Total-Ms": ( f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" ), - "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", } if job.result is not None: @@ -939,16 +1265,48 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str( + int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str( + int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0)) + ), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str( + int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str( + int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0)) + ), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str( + int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0)) + ), + "X-Prepare-Target-Bert-High-Pressure-Peak": str( + int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0)) + ), "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) ), "X-Prepare-Target-Bert-Batch-Size-Peak": str( int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) ), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", @@ -964,13 +1322,22 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", - "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", "X-Prepare-Inflight-On-Enter": str( int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) ), "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + ) + headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" + headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" + headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) except Exception as e: return JSONResponse(