Refactor TTS and scheduler components to enhance text processing and batching capabilities. Introduce PrepareCoordinator for managing text feature preparation asynchronously, and update SchedulerDebugWorker to support new finalize task management. Implement batch processing in PrepareBertBatchWorker with improved admission control and profiling metrics. Add text CPU preprocessing utilities for better text segmentation and normalization.

This commit is contained in:
baicai-1145 2026-03-10 06:58:53 +08:00
parent a45e171ff5
commit 827d6ea47c
8 changed files with 1811 additions and 272 deletions

View File

@ -1,26 +1,27 @@
import gc
import concurrent.futures
import math
import os
import random
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
import torchaudio
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
import os
from typing import List, Tuple, Union
from runtime_preload import preload_text_runtime_deps
preload_text_runtime_deps()
import ffmpeg
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import yaml
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from BigVGAN.bigvgan import BigVGAN
@ -30,6 +31,7 @@ from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
from peft import LoraConfig, get_peft_model
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from transformers import AutoModelForMaskedLM, AutoTokenizer
from tqdm import tqdm
from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -449,20 +451,21 @@ class TTS:
"overlapped_len": None,
}
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2")))
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None
default_text_cpu_workers = 16
self.prepare_text_cpu_workers = max(
0,
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))),
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")),
)
self.prepare_text_cpu_executor = None
if self.prepare_text_cpu_workers > 0:
self.prepare_text_cpu_executor = ThreadPoolExecutor(
self.prepare_text_cpu_executor = (
concurrent.futures.ThreadPoolExecutor(
max_workers=self.prepare_text_cpu_workers,
thread_name_prefix="prepare-text-cpu",
)
if self.prepare_text_cpu_workers > 0
else None
)
self._init_models()
@ -475,6 +478,20 @@ class TTS:
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")),
admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")),
high_pressure_pending_threshold=int(
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0")
),
high_pressure_batch_window_ms=int(
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1")
),
high_pressure_max_batch_items=int(
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32")
),
high_pressure_max_batch_tokens=int(
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192")
),
)
if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
@ -830,13 +847,23 @@ class TTS:
return codes[0, 0].to(self.configs.device)
@torch.inference_mode()
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
cpu_prepare_start = time.perf_counter()
wav16k = prepare_prompt_semantic_wav16k(
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
)
return self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
forward_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
return prompt_semantic, cpu_prepare_ms, forward_ms
@torch.inference_mode()
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
return prompt_semantic
def extract_prompt_semantic(self, ref_wav_path: str):
raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path)
@ -887,7 +914,9 @@ class TTS:
if self.prepare_ref_semantic_batch_worker is None:
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = (
self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
)
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
ref_spec_start = time.perf_counter()
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
@ -897,8 +926,8 @@ class TTS:
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": prompt_semantic_ms,
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
@ -1012,6 +1041,32 @@ class TTS:
text, language, self.configs.version, profile=profile
)
def prepare_text_segments(self, text: str, language: str):
    """Run the text preprocessor's segment preparation for ``text``.

    The configured model version (``self.configs.version``) is forwarded as
    the third argument; the preprocessor's result is returned unchanged.
    """
    preprocess = self.text_preprocessor.preprocess_text_segments
    return preprocess(text, language, self.configs.version)
def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None):
return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None):
return await self.text_preprocessor.build_phones_and_bert_from_segments_async(
prepared_segments,
profile=profile,
)
async def build_text_feature_pair_from_segments_async(
self,
prompt_segments,
target_segments,
prompt_profile: dict | None = None,
target_profile: dict | None = None,
):
return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async(
prompt_segments,
target_segments,
prompt_profile=prompt_profile,
target_profile=target_profile,
)
def _set_ref_audio_path(self, ref_audio_path):
self.prompt_cache["ref_audio_path"] = ref_audio_path
@ -2011,6 +2066,79 @@ class TTS:
sample_steps=sample_steps,
)
@torch.inference_mode()
def synthesize_audio_requests_local_batched(
self,
semantic_tokens_list: List[torch.Tensor],
phones_list: List[torch.Tensor],
refer_specs: List[tuple],
speeds: List[float] | None = None,
sample_steps_list: List[int] | None = None,
) -> List[torch.Tensor]:
batch_size = len(semantic_tokens_list)
if batch_size == 0:
return []
if len(phones_list) != batch_size or len(refer_specs) != batch_size:
raise ValueError("batched request-local synthesis 输入长度不一致")
if speeds is None:
speeds = [1.0] * batch_size
if sample_steps_list is None:
sample_steps_list = [32] * batch_size
if len(speeds) != batch_size or len(sample_steps_list) != batch_size:
raise ValueError("batched request-local synthesis 参数长度不一致")
first_speed = float(speeds[0])
first_sample_steps = int(sample_steps_list[0])
if any(abs(float(item) - first_speed) > 1e-6 for item in speeds):
raise ValueError("batched request-local synthesis 目前要求 speed 一致")
if any(int(item) != first_sample_steps for item in sample_steps_list):
raise ValueError("batched request-local synthesis 目前要求 sample_steps 一致")
if self.configs.use_vocoder:
raise NotImplementedError("request-local batched VITS synthesis 暂不支持 vocoder 模型")
device = self.configs.device
max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list)
max_phone_len = max(int(item.shape[-1]) for item in phones_list)
semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device)
phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device)
semantic_lengths = []
phone_lengths = []
refer_audio_specs: List[torch.Tensor] = []
sv_emb_batch = None
sv_emb_list: List[torch.Tensor] = []
for batch_index, semantic_tokens in enumerate(semantic_tokens_list):
semantic_len = int(semantic_tokens.shape[-1])
phone_len = int(phones_list[batch_index].shape[-1])
semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long)
phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long)
semantic_lengths.append(semantic_len)
phone_lengths.append(phone_len)
refer_audio_spec, audio_tensor = refer_specs[batch_index]
refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device))
if self.is_v2pro:
if audio_tensor is None:
raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频"))
sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device))
if self.is_v2pro:
sv_emb_batch = torch.cat(sv_emb_list, dim=0)
audio_batch, audio_lengths = self.vits_model.decode_batched_request_local(
codes=semantic_batch,
code_lengths=torch.LongTensor(semantic_lengths).to(device),
text=phone_batch,
text_lengths=torch.LongTensor(phone_lengths).to(device),
refer_list=refer_audio_specs,
speed=first_speed,
sv_emb=sv_emb_batch,
)
audios: List[torch.Tensor] = []
for batch_index in range(batch_size):
audio_len = int(audio_lengths[batch_index].item())
audios.append(audio_batch[batch_index, 0, :audio_len].detach())
return audios
def using_vocoder_synthesis_batched_infer(
self,
idx_list: List[int],

View File

@ -1,8 +1,10 @@
import asyncio
import os
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass
from tqdm import tqdm
@ -13,12 +15,13 @@ import re
import torch
from text.LangSegmenter import LangSegmenter
from text import chinese
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -92,6 +95,14 @@ class StageLimiter:
}
@dataclass
class PreparedTextSegment:
    """One preprocessed text segment: language tag, phones, word2ph and normalized text."""

    # Language tag of the segment; consumers strip an "all_" prefix before comparing.
    language: str
    # Phone id sequence for this segment.
    phones: List[int]
    # Optional word2ph list; the BERT path rejects "zh" segments where this is None.
    word2ph: Optional[List[int]]
    # Normalized text of the segment.
    norm_text: str
class TextPreprocessor:
def __init__(
self,
@ -149,7 +160,7 @@ class TextPreprocessor:
# 解决输入目标文本的空行导致报错的问题
if len(text.strip()) == 0:
continue
if not re.sub("\W+", "", text):
if not re.sub(r"\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
if text[-1] not in splits:
@ -168,7 +179,8 @@ class TextPreprocessor:
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version, profile=profile)
prepared_segments = self.preprocess_text_segments(text, language, version)
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
textlist = []
@ -223,24 +235,49 @@ class TextPreprocessor:
def get_phones_and_bert(
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
):
text = re.sub(r' {2,}', ' ', text)
textlist, langlist = self._split_text_by_language(text, language)
phones_list = []
bert_list = []
norm_text_list = []
for segment_text, segment_lang in zip(textlist, langlist):
phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile)
phones_list.append(phones)
norm_text_list.append(norm_text)
prepared_segments = self.preprocess_text_segments(text, language, version, final=final)
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
def preprocess_text_segments(
    self,
    text: str,
    language: str,
    version: str,
    final: bool = False,
) -> List[PreparedTextSegment]:
    """Convert the payloads from ``preprocess_text_segments_payload`` into dataclasses.

    Each payload's ``language`` and ``norm_text`` are coerced to ``str``;
    ``phones`` and a non-``None`` ``word2ph`` are copied into fresh lists.
    """
    segments: List[PreparedTextSegment] = []
    for payload in preprocess_text_segments_payload(text, language, version, final=final):
        language_tag = str(payload["language"])
        phones = list(payload["phones"])
        raw_word2ph = payload["word2ph"]
        # word2ph stays None when the payload has none.
        word2ph = list(raw_word2ph) if raw_word2ph is not None else None
        segments.append(
            PreparedTextSegment(
                language=language_tag,
                phones=phones,
                word2ph=word2ph,
                norm_text=str(payload["norm_text"]),
            )
        )
    return segments
def build_phones_and_bert_from_segments(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> Tuple[list, torch.Tensor, str]:
phones_list: List[List[int]] = []
bert_list: List[torch.Tensor] = []
norm_text_list: List[str] = []
for segment in prepared_segments:
bert = self.get_bert_inf(
segment.phones,
segment.word2ph,
segment.norm_text,
segment.language,
profile=profile,
)
phones_list.append(segment.phones)
norm_text_list.append(segment.norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile)
return phones, bert, norm_text
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
@ -253,21 +290,41 @@ class TextPreprocessor:
return
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None:
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0))
self._accumulate_profile(
profile,
"bert_batch_collect_wait_ms",
worker_profile.get("bert_batch_collect_wait_ms", 0.0),
)
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0))
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
self._update_profile_peak(
profile,
"bert_pending_depth_on_enqueue_peak",
worker_profile.get("bert_pending_depth_on_enqueue", 0.0),
)
self._update_profile_peak(
profile,
"bert_pending_depth_on_collect_peak",
worker_profile.get("bert_pending_depth_on_collect", 0.0),
)
self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0))
if profile is not None:
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0))
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
if self.bert_batch_worker is not None:
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
self._update_profile_peak(
profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)
)
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
if profile is not None:
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
self._merge_bert_worker_profile(profile, worker_profile)
return feature
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
@ -310,9 +367,18 @@ class TextPreprocessor:
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None):
def get_bert_inf(
self,
phones: list,
word2ph: Optional[list],
norm_text: str,
language: str,
profile: Dict | None = None,
):
language = language.replace("all_", "")
if language == "zh":
if word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
else:
feature = torch.zeros(
@ -322,6 +388,112 @@ class TextPreprocessor:
return feature
async def build_phones_and_bert_from_segments_async(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> Tuple[list, torch.Tensor, str]:
segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
for segment_index, segment in enumerate(prepared_segments):
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
continue
if segment.word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
pending_items.append(
(
segment_jobs["bert_list"],
segment_index,
profile,
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
)
)
if pending_items:
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
self._merge_bert_worker_profile(item_profile, worker_profile)
bert_list[bert_index] = feature.to(self.device)
return self._finalize_async_segment_jobs(segment_jobs)
def _build_async_segment_jobs(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None,
) -> Dict[str, List]:
phones_list: List[List[int]] = []
bert_list: List[torch.Tensor | None] = []
norm_text_list: List[str] = []
for segment in prepared_segments:
phones_list.append(segment.phones)
norm_text_list.append(segment.norm_text)
segment_language = segment.language.replace("all_", "")
if segment_language == "zh" and self.bert_batch_worker is not None:
if segment.word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
bert_list.append(None)
continue
bert_list.append(
self.get_bert_inf(
segment.phones,
segment.word2ph,
segment.norm_text,
segment.language,
profile=profile,
)
)
return {
"phones_list": phones_list,
"bert_list": bert_list,
"norm_text_list": norm_text_list,
}
@staticmethod
def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]:
bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1)
phones = sum(segment_jobs["phones_list"], [])
norm_text = "".join(segment_jobs["norm_text_list"])
return phones, bert, norm_text
async def build_phones_and_bert_pair_from_segments_async(
self,
prompt_segments: List[PreparedTextSegment],
target_segments: List[PreparedTextSegment],
prompt_profile: Dict | None = None,
target_profile: Dict | None = None,
) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
for segment_jobs, prepared_segments, profile in (
(prompt_jobs, prompt_segments, prompt_profile),
(target_jobs, target_segments, target_profile),
):
for segment_index, segment in enumerate(prepared_segments):
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
continue
if segment.word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
pending_items.append(
(
segment_jobs["bert_list"],
segment_index,
profile,
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
)
)
if pending_items:
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
self._merge_bert_worker_profile(profile, worker_profile)
bert_list[bert_index] = feature.to(self.device)
return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs)
def filter_text(self, texts):
_text = []
if all(text in [None, " ", "\n", ""] for text in texts):

View File

@ -1,3 +1,4 @@
import asyncio
import threading
import time
import uuid
@ -14,7 +15,12 @@ class BertFeatureTask:
word2ph: List[int]
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
enqueued_at: float = 0.0
admission_wait_ms: float = 0.0
pending_depth_on_enqueue: int = 0
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
result_feature: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
@ -30,14 +36,37 @@ class PrepareBertBatchWorker:
batch_window_ms: int = 5,
max_batch_items: int = 16,
max_batch_tokens: int = 4096,
max_pending_tasks: int = 0,
admission_poll_ms: int = 1,
high_pressure_pending_threshold: int = 0,
high_pressure_batch_window_ms: int | None = None,
high_pressure_max_batch_items: int | None = None,
high_pressure_max_batch_tokens: int | None = None,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.stage_limiter = stage_limiter
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
self.batch_window_ms = max(0, int(batch_window_ms))
self.batch_window_s = float(self.batch_window_ms) / 1000.0
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_tokens = max(16, int(max_batch_tokens))
self.max_pending_tasks = max(0, int(max_pending_tasks))
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
self.high_pressure_pending_threshold = max(
0,
int(high_pressure_pending_threshold)
if int(high_pressure_pending_threshold) > 0
else max(self.max_batch_items * 2, 32),
)
hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms)
hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items)
hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens)
self.high_pressure_batch_window_ms = max(0, hp_window_ms)
self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0
self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items)
self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens)
self.condition = threading.Condition()
self.pending_tasks: Deque[BertFeatureTask] = deque()
@ -47,26 +76,70 @@ class PrepareBertBatchWorker:
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.active_batch_tokens = 0
self.active_batch_tokens_peak = 0
self.high_pressure_batches = 0
self.admission_wait_total_ms = 0.0
self.admission_wait_peak_ms = 0.0
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
self.worker_thread.start()
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
return max(1, len(task.norm_text) + 2)
def _can_enqueue_locked(self) -> bool:
if self.max_pending_tasks <= 0:
return True
return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks
def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None:
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
task.enqueued_at = time.perf_counter()
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
self.pending_tasks.append(task)
self.total_submitted += 1
self.admission_wait_total_ms += task.admission_wait_ms
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
def _enqueue_task(self, task: BertFeatureTask) -> None:
admission_started = time.perf_counter()
with self.condition:
while not self._can_enqueue_locked():
self.condition.wait(timeout=self.admission_poll_s)
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
async def _enqueue_task_async(self, task: BertFeatureTask) -> None:
admission_started = time.perf_counter()
while True:
with self.condition:
if self._can_enqueue_locked():
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
return
await asyncio.sleep(self.admission_poll_s)
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
self._enqueue_task(task)
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_feature is not None
return task.result_feature, dict(task.profile)
async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Submit one BERT feature task from a coroutine and await its outcome.

    A future created on the running loop is attached to the task, which is
    enqueued through ``_enqueue_task_async`` and then awaited.

    Returns:
        Whatever the task's ``done_future`` resolves to.
    """
    text = str(norm_text)
    word2ph_copy = list(word2ph)
    running_loop = asyncio.get_running_loop()
    task = BertFeatureTask(
        norm_text=text,
        word2ph=word2ph_copy,
        done_loop=running_loop,
        done_future=running_loop.create_future(),
    )
    await self._enqueue_task_async(task)
    outcome = await task.done_future
    return outcome
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
@ -77,21 +150,57 @@ class PrepareBertBatchWorker:
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"active_batch_tokens": self.active_batch_tokens,
"active_batch_tokens_peak": self.active_batch_tokens_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_tokens": self.max_batch_tokens,
"max_pending_tasks": self.max_pending_tasks,
"high_pressure_pending_threshold": self.high_pressure_pending_threshold,
"high_pressure_batch_window_ms": self.high_pressure_batch_window_ms,
"high_pressure_max_batch_items": self.high_pressure_max_batch_items,
"high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens,
"high_pressure_batches": self.high_pressure_batches,
"admission_wait_total_ms": self.admission_wait_total_ms,
"admission_wait_peak_ms": self.admission_wait_peak_ms,
}
def _collect_batch(self) -> List[BertFeatureTask]:
def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]:
pending_depth = len(self.pending_tasks)
use_high_pressure = (
self.high_pressure_pending_threshold > 0
and pending_depth >= self.high_pressure_pending_threshold
)
if use_high_pressure:
return (
self.high_pressure_batch_window_s,
self.high_pressure_max_batch_items,
self.high_pressure_max_batch_tokens,
True,
pending_depth,
)
return (
self.batch_window_s,
self.max_batch_items,
self.max_batch_tokens,
False,
pending_depth,
)
def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
collect_started = time.perf_counter()
batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = (
self._select_batch_policy_locked()
)
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
batch_tokens = self._estimate_task_tokens(batch[0])
deadline = time.perf_counter() + self.batch_window_s
deadline = time.perf_counter() + batch_window_s
while len(batch) < self.max_batch_items:
while len(batch) < max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
@ -100,26 +209,39 @@ class PrepareBertBatchWorker:
continue
next_task = self.pending_tasks[0]
next_tokens = self._estimate_task_tokens(next_task)
if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens:
if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens:
break
batch.append(self.pending_tasks.popleft())
batch_tokens += next_tokens
self.active_batch_size = len(batch)
self.active_batch_tokens = batch_tokens
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
return batch
if self.active_batch_tokens > self.active_batch_tokens_peak:
self.active_batch_tokens_peak = self.active_batch_tokens
if use_high_pressure:
self.high_pressure_batches += 1
return batch, {
"collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0,
"batch_tokens": float(batch_tokens),
"pending_depth_on_collect": float(pending_depth_on_collect),
"high_pressure_mode": 1.0 if use_high_pressure else 0.0,
"batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms),
}
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.active_batch_tokens = 0
self.total_batches += 1
self.total_finished += len(batch)
self.condition.notify_all()
def _run_batch(self, batch: List[BertFeatureTask]) -> None:
def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None:
batch_started = time.perf_counter()
texts = [task.norm_text for task in batch]
batch_tokens = sum(self._estimate_task_tokens(task) for task in batch)
batch_tokens = int(batch_meta["batch_tokens"])
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
if self.stage_limiter is None:
@ -167,6 +289,9 @@ class PrepareBertBatchWorker:
task.result_feature = torch.cat(phone_level_feature, dim=0).T
task.profile = {
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"bert_admission_wait_ms": float(task.admission_wait_ms),
"bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0),
"bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]),
"bert_forward_ms": float(forward_ms),
"bert_tokenize_ms": float(tokenize_ms),
"bert_scatter_ms": 0.0,
@ -175,6 +300,10 @@ class PrepareBertBatchWorker:
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"bert_batch_size": float(len(batch)),
"bert_batch_tokens": float(batch_tokens),
"bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
"bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]),
"bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]),
"bert_batch_window_ms": float(batch_meta["batch_window_ms"]),
}
except Exception as exc: # noqa: PERF203
task.error = exc
@ -183,15 +312,35 @@ class PrepareBertBatchWorker:
if task.result_feature is not None:
task.profile["bert_scatter_ms"] = float(scatter_ms)
task.done_event.set()
self._notify_done_future(task)
@staticmethod
def _resolve_done_future(task: BertFeatureTask) -> None:
if task.done_future is None or task.done_future.done():
return
if task.error is not None:
task.done_future.set_exception(task.error)
return
assert task.result_feature is not None
task.done_future.set_result((task.result_feature, dict(task.profile)))
def _notify_done_future(self, task: BertFeatureTask) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
except RuntimeError:
pass
def _run_loop(self) -> None:
while True:
batch = self._collect_batch()
batch, batch_meta = self._collect_batch()
try:
self._run_batch(batch)
self._run_batch(batch, batch_meta)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
self._notify_done_future(task)
finally:
self._finalize_batch(batch)

View File

@ -0,0 +1,294 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import os
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
PreparedTextFeatures,
SchedulerRequestSpec,
T2SRequestState,
build_request_state_from_parts,
normalize_sentence,
)
@dataclass
class ProfiledResult:
    """A stage result together with the ``perf_counter`` stamps taken around it."""

    result: Any
    submit_at: float  # when the work was handed to the stage
    started_at: float  # when the work actually began running
    finished_at: float  # when the work returned

    @property
    def queue_ms(self) -> float:
        """Milliseconds between submission and start, never negative."""
        waited_ms = (self.started_at - self.submit_at) * 1000.0
        return max(0.0, waited_ms)

    @property
    def run_ms(self) -> float:
        """Milliseconds between start and finish, never negative."""
        ran_ms = (self.finished_at - self.started_at) * 1000.0
        return max(0.0, ran_ms)
class PrepareCoordinator:
    """Runs the per-request prepare stages (text CPU, text features, reference audio) from asyncio.

    The coordinator owns two optional thread pools (text features, reference
    audio), an optional admission semaphore, and inflight/peak counters that
    are guarded by ``self.lock``.

    Environment variables read at construction time:

    * ``GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT`` -- any value other than ``"0"``
      selects the async text-feature path, provided ``tts`` exposes a
      non-``None`` ``prepare_bert_batch_worker``.
    * ``GPTSOVITS_PREPARE_MAX_INFLIGHT`` -- maximum concurrent prepares;
      ``0`` (the default) disables admission control.
    * ``GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS`` -- text-feature pool size
      (only used on the executor path).
    * ``GPTSOVITS_PREPARE_REF_SLOTS`` / ``GPTSOVITS_PREPARE_REF_ASYNC_WORKERS``
      -- default and override for the reference-audio pool size.
    """

    def __init__(self, tts: Any):
        self.tts = tts
        self.lock = threading.Lock()
        self.inflight = 0
        self.peak_inflight = 0
        self.use_async_text_feature_path = bool(
            getattr(tts, "prepare_bert_batch_worker", None) is not None
            and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0"
        )
        self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0")))
        self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None
        self.text_feature_workers = 0
        self.text_feature_executor = None
        if not self.use_async_text_feature_path:
            # Executor path: text features are built on a dedicated thread pool.
            text_feature_default_workers = max(1, int(getattr(tts, "prepare_text_cpu_workers", 16) or 16))
            self.text_feature_workers = max(
                1,
                int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS", str(text_feature_default_workers))),
            )
            self.text_feature_executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=self.text_feature_workers,
                thread_name_prefix="prepare-text-feature",
            )
        ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
        self.ref_audio_workers = max(
            1,
            int(os.environ.get("GPTSOVITS_PREPARE_REF_ASYNC_WORKERS", str(ref_audio_default_workers))),
        )
        self.ref_audio_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.ref_audio_workers,
            thread_name_prefix="prepare-ref-audio",
        )

    def _mark_enter(self) -> Tuple[int, int]:
        """Count one more inflight prepare; return ``(current_inflight, peak_inflight)``."""
        with self.lock:
            self.inflight += 1
            current_inflight = self.inflight
            if current_inflight > self.peak_inflight:
                self.peak_inflight = current_inflight
            return current_inflight, self.peak_inflight

    def _mark_leave(self) -> None:
        """Count one fewer inflight prepare, clamped at zero."""
        with self.lock:
            self.inflight = max(0, self.inflight - 1)

    def snapshot(self) -> Dict[str, int]:
        """Return the current counters and pool sizes as plain ints."""
        with self.lock:
            return {
                "inflight": int(self.inflight),
                "peak_inflight": int(self.peak_inflight),
                "max_inflight": int(self.max_inflight),
                "text_feature_workers": int(self.text_feature_workers),
                "ref_audio_workers": int(self.ref_audio_workers),
            }

    @staticmethod
    def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult:
        """Call ``fn(*args)`` and wrap its result with submit/start/finish timestamps."""
        started_at = time.perf_counter()
        result = fn(*args)
        finished_at = time.perf_counter()
        return ProfiledResult(
            result=result,
            submit_at=float(submit_at),
            started_at=float(started_at),
            finished_at=float(finished_at),
        )

    def _prepare_text_cpu(self, text: str, language: str):
        """Delegate CPU-side text preparation to ``tts.prepare_text_segments``."""
        return self.tts.prepare_text_segments(text, language)

    def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
        """Build text features from prepared segments and package them with timings.

        ``language`` is accepted but not used here. ``cpu_run_ms`` is the time
        the preceding CPU stage took; it is included in ``total_ms`` and
        excluded from ``profile["bert_total_ms"]``.
        """
        profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
        branch_start = time.perf_counter()
        phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile)
        total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0)
        profile["bert_total_ms"] = max(0.0, total_ms - float(cpu_run_ms))
        return PreparedTextFeatures(
            phones=phones,
            bert_features=bert_features,
            norm_text=norm_text,
            profile=profile,
            total_ms=total_ms,
            cpu_preprocess_ms=float(cpu_run_ms),
        )

    async def _run_on_executor(self, executor, fn, *args) -> ProfiledResult:
        """Run ``fn(*args)`` on ``executor`` via the running loop, profiled from submission."""
        loop = asyncio.get_running_loop()
        submit_at = time.perf_counter()
        return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args)

    async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult:
        """Run the text CPU stage on ``tts.prepare_text_cpu_executor``.

        When the TTS object has no such executor the stage runs inline in the
        calling coroutine.
        """
        executor = getattr(self.tts, "prepare_text_cpu_executor", None)
        if executor is None:
            submit_at = time.perf_counter()
            return self._run_profiled(self._prepare_text_cpu, submit_at, text, language)
        return await self._run_on_executor(executor, self._prepare_text_cpu, text, language)

    async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult:
        """Run ``_build_text_features`` on the text-feature executor."""
        return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms)

    @staticmethod
    def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
        """Sum the BERT wait, tokenize, forward and scatter timings found in ``profile``."""
        return float(
            profile.get("bert_wait_ms", 0.0)
            + profile.get("bert_tokenize_ms", 0.0)
            + profile.get("bert_forward_ms", 0.0)
            + profile.get("bert_scatter_ms", 0.0)
        )

    async def _run_text_feature_pair_stage(
        self,
        prompt_segments,
        target_segments,
        prompt_cpu_run_ms: float,
        target_cpu_run_ms: float,
    ) -> tuple[ProfiledResult, ProfiledResult]:
        """Build prompt and target text features concurrently.

        With a text-feature executor both sides run as independent executor
        jobs. Without one, both are awaited through
        ``tts.build_text_feature_pair_from_segments_async`` and the per-side
        ``ProfiledResult`` timestamps are synthesized from the profile dicts
        that call fills in.
        """
        if self.text_feature_executor is not None:
            prompt_feature_task = asyncio.create_task(
                self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms)
            )
            target_feature_task = asyncio.create_task(
                self._run_text_feature_stage(target_segments, None, target_cpu_run_ms)
            )
            return await asyncio.gather(prompt_feature_task, target_feature_task)
        prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
        target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
        submit_at = time.perf_counter()
        started_at = float(submit_at)
        prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
            prompt_segments,
            target_segments,
            prompt_profile=prompt_profile,
            target_profile=target_profile,
        )
        finished_at = time.perf_counter()
        prompt_result = PreparedTextFeatures(
            phones=prompt_result_raw[0],
            bert_features=prompt_result_raw[1],
            norm_text=prompt_result_raw[2],
            profile=prompt_profile,
            total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)),
            cpu_preprocess_ms=float(prompt_cpu_run_ms),
        )
        target_result = PreparedTextFeatures(
            phones=target_result_raw[0],
            bert_features=target_result_raw[1],
            norm_text=target_result_raw[2],
            profile=target_profile,
            total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
            cpu_preprocess_ms=float(target_cpu_run_ms),
        )
        # No real per-side start/finish exists on this path, so the finish
        # stamp is submit time plus the estimated run time of that side.
        prompt_profiled = ProfiledResult(
            result=prompt_result,
            submit_at=float(submit_at),
            started_at=started_at,
            finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0),
        )
        target_profiled = ProfiledResult(
            result=target_result,
            submit_at=float(submit_at),
            started_at=started_at,
            finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
        )
        if finished_at > prompt_profiled.finished_at:
            # The awaited call took longer than the prompt estimate: report at
            # least the measured wall time for both sides.
            prompt_result.profile["bert_total_ms"] = max(
                self._estimate_text_feature_run_ms(prompt_profile),
                (finished_at - submit_at) * 1000.0,
            )
            target_result.profile["bert_total_ms"] = max(
                self._estimate_text_feature_run_ms(target_profile),
                (finished_at - submit_at) * 1000.0,
            )
        else:
            prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile)
            target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
        return prompt_profiled, target_profiled

    async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
        """Run ``tts.extract_ref_audio_bundle`` on the reference-audio executor."""
        return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)

    async def prepare_state_profiled_async(
        self,
        spec: SchedulerRequestSpec,
        prepare_submit_at: float,
    ) -> tuple[T2SRequestState, float, float]:
        """Prepare one request and return ``(state, prepare_start, prepare_exec_finished_at)``.

        Waits for an admission slot when a limit is configured, runs prompt and
        target text CPU stages plus the reference-audio stage concurrently,
        then the text-feature pair, and finally assembles the request state
        with the collected timings as profile overrides.

        ``prepare_submit_at`` is the ``perf_counter`` stamp taken by the caller
        at submission; it only feeds ``executor_queue_ms``.

        Everything after admission runs inside ``try``/``finally`` so that a
        failure while normalizing the request text can no longer leak the
        semaphore slot or the inflight counter.
        """
        admission_start = time.perf_counter()
        if self._inflight_semaphore is not None:
            await self._inflight_semaphore.acquire()
        prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0)
        current_inflight, peak_inflight = self._mark_enter()
        try:
            prepare_start = time.perf_counter()
            prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
            text = spec.text.strip("\n")
            text_pair_start = time.perf_counter()
            prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang))
            target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang))
            # Reference audio does not depend on text, so it overlaps with both text stages.
            ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path)))
            prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task)
            text_feature_pair_task = asyncio.create_task(
                self._run_text_feature_pair_stage(
                    prompt_cpu_profiled.result,
                    target_cpu_profiled.result,
                    prompt_cpu_profiled.run_ms,
                    target_cpu_profiled.run_ms,
                )
            )
            (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather(
                text_feature_pair_task,
                ref_audio_task,
            )
            text_pair_end = time.perf_counter()
            state = build_request_state_from_parts(
                tts=self.tts,
                spec=spec,
                prompt_text=prompt_text,
                text=text,
                prompt_result=prompt_feature_profiled.result,
                target_result=target_feature_profiled.result,
                ref_audio_bundle=ref_audio_profiled.result,
                prepare_start=prepare_start,
                prepare_sync_start=prepare_start,
                profile_overrides={
                    "executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0),
                    "prepare_admission_wait_ms": prepare_admission_wait_ms,
                    "executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0),
                    "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
                    "prompt_text_parallel_future_wait_ms": 0.0,
                    "prompt_text_parallel_future_executor_queue_ms": 0.0,
                    "prompt_text_parallel_future_run_ms": 0.0,
                    "prompt_text_parallel_future_finish_after_submit_ms": 0.0,
                    "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
                    "prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
                    "prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms,
                    "prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms,
                    "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
                    "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
                    "text_cpu_queue_ms": target_cpu_profiled.queue_ms,
                    "text_cpu_run_ms": target_cpu_profiled.run_ms,
                    "text_feature_queue_ms": target_feature_profiled.queue_ms,
                    "text_feature_run_ms": target_feature_profiled.run_ms,
                    "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
                    "ref_audio_task_run_ms": ref_audio_profiled.run_ms,
                    "worker_prepare_inflight_on_enter": float(current_inflight),
                    "worker_prepare_peak_inflight": float(peak_inflight),
                },
            )
            prepare_exec_finished_at = time.perf_counter()
            # Refresh the wall time now that state assembly itself is included.
            state.prepare_profile["executor_run_wall_ms"] = max(
                0.0, (prepare_exec_finished_at - prepare_start) * 1000.0
            )
            return state, prepare_start, prepare_exec_finished_at
        finally:
            self._mark_leave()
            if self._inflight_semaphore is not None:
                self._inflight_semaphore.release()

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from concurrent.futures import Future
from dataclasses import dataclass
from pathlib import Path
import time
@ -89,20 +88,31 @@ class T2SFinishedItem:
class T2SActiveBatch:
request_ids: List[str]
states: List[T2SRequestState]
x: torch.Tensor
x_lens: torch.LongTensor
x: Optional[torch.Tensor]
x_lens: Optional[torch.LongTensor]
y_sequences: List[torch.LongTensor]
prefix_lens: torch.LongTensor
xy_pos: torch.Tensor
key_padding_mask: torch.Tensor
prefill_attn_mask: torch.Tensor
key_padding_mask: Optional[torch.Tensor]
prefill_attn_mask: Optional[torch.Tensor]
decode_attn_mask: Optional[torch.Tensor]
k_cache: Optional[List[torch.Tensor]]
v_cache: Optional[List[torch.Tensor]]
step_idx: int
kv_lens: Optional[torch.LongTensor]
step_indices: torch.LongTensor
prefill_done: bool
@dataclass
class PreparedTextFeatures:
    """Text-side inputs for one sentence plus the timings gathered while building them."""

    phones: List[int]  # phone id sequence
    bert_features: torch.Tensor  # BERT features for the sentence
    norm_text: str  # normalized text
    profile: Dict[str, float]  # per-stage timing/profile values
    total_ms: float  # total preparation time in milliseconds
    cpu_preprocess_ms: float  # CPU preprocessing share of that time
def normalize_sentence(text: str, language: str) -> str:
text = text.strip("\n").strip()
if not text:
@ -113,105 +123,125 @@ def normalize_sentence(text: str, language: str) -> str:
@torch.inference_mode()
def prepare_text_features(
    tts: Any,
    text: str,
    language: str,
) -> PreparedTextFeatures:
    """Run CPU text preparation and feature building for one sentence, timing both.

    The orphaned ``def prepare_request_state(`` header that preceded this
    signature has been dropped; it belonged to the function's former name and
    left two ``def`` headers in a row.

    Args:
        tts: object providing ``configs.device``, ``prepare_text_segments`` and
            ``build_text_features_from_segments``.
        text: sentence to process.
        language: language tag forwarded to ``prepare_text_segments``.

    Returns:
        ``PreparedTextFeatures`` whose ``profile`` holds ``cpu_preprocess_ms``
        and ``bert_total_ms`` in addition to whatever
        ``build_text_features_from_segments`` records.
    """
    device = tts.configs.device
    profile: Dict[str, float] = {}
    branch_start = time.perf_counter()
    # NOTE(review): _sync_device presumably waits for pending device work so
    # the timers bracket real work -- confirm against its definition.
    _sync_device(device)
    cpu_start = time.perf_counter()
    prepared_segments = tts.prepare_text_segments(text, language)
    _sync_device(device)
    cpu_preprocess_ms = (time.perf_counter() - cpu_start) * 1000.0
    profile["cpu_preprocess_ms"] = float(cpu_preprocess_ms)
    bert_start = time.perf_counter()
    phones, bert_features, norm_text = tts.build_text_features_from_segments(prepared_segments, profile=profile)
    _sync_device(device)
    profile["bert_total_ms"] = (time.perf_counter() - bert_start) * 1000.0
    total_ms = (time.perf_counter() - branch_start) * 1000.0
    return PreparedTextFeatures(
        phones=phones,
        bert_features=bert_features,
        norm_text=norm_text,
        profile=profile,
        total_ms=float(total_ms),
        cpu_preprocess_ms=float(cpu_preprocess_ms),
    )
@torch.inference_mode()
def build_request_state_from_parts(
tts: Any,
spec: SchedulerRequestSpec,
prompt_text: str,
text: str,
prompt_result: PreparedTextFeatures,
target_result: PreparedTextFeatures,
ref_audio_bundle: Dict[str, Any],
prepare_start: float,
prepare_sync_start: float,
profile_overrides: Optional[Dict[str, float]] = None,
) -> T2SRequestState:
device = tts.configs.device
prepare_start = time.perf_counter()
_sync_device(device)
prepare_sync_start = time.perf_counter()
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
text = spec.text.strip("\n")
prompt_text_profile: Dict[str, float] = {}
text_features_profile: Dict[str, float] = {}
text_feature_pair_start = time.perf_counter()
prompt_future: Future | None = None
def _extract_prompt_features():
_sync_device(device)
prompt_start = time.perf_counter()
result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile)
_sync_device(device)
return result, (time.perf_counter() - prompt_start) * 1000.0
if getattr(tts, "prepare_text_cpu_executor", None) is not None:
prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features)
_sync_device(device)
text_features_start = time.perf_counter()
phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile)
_sync_device(device)
text_features_ms = (time.perf_counter() - text_features_start) * 1000.0
if prompt_future is None:
_sync_device(device)
prompt_text_features_start = time.perf_counter()
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(
prompt_text, spec.prompt_lang, profile=prompt_text_profile
)
_sync_device(device)
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
prompt_text_profile["parallel_future_wait_ms"] = 0.0
else:
prompt_wait_start = time.perf_counter()
(prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result()
prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0
text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0
if phones is None:
raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
_sync_device(device)
ref_audio_bundle_start = time.perf_counter()
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0))
bundle_profile = ref_audio_bundle.get("profile", {})
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
raw_audio = ref_audio_bundle["raw_audio"]
raw_sr = int(ref_audio_bundle["raw_sr"])
_sync_device(device)
ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0
bundle_profile = ref_audio_bundle.get("profile", {})
prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0))
audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0))
_sync_device(device)
tensorize_start = time.perf_counter()
phones_tensor = torch.LongTensor(phones).to(tts.configs.device)
prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device)
all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device)
all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to(
phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device)
prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device)
all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device)
all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to(
dtype=tts.precision, device=tts.configs.device
)
_sync_device(device)
tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0
_sync_device(device)
prepare_profile = {
"prompt_text_features_ms": prompt_text_features_ms,
"text_features_ms": text_features_ms,
"prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)),
"prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)),
"prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)),
"prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)),
"prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)),
"prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)),
"prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)),
"prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)),
"prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)),
"prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)),
"text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)),
"text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)),
"text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)),
"text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)),
"text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)),
"text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)),
"text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)),
"text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)),
"text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)),
"text_feature_pair_ms": text_feature_pair_ms,
"prompt_text_features_ms": float(prompt_result.total_ms),
"text_features_ms": float(target_result.total_ms),
"prompt_text_cpu_preprocess_ms": float(prompt_result.cpu_preprocess_ms),
"text_cpu_preprocess_ms": float(target_result.cpu_preprocess_ms),
"prompt_text_bert_wait_ms": float(prompt_result.profile.get("bert_wait_ms", 0.0)),
"prompt_text_bert_admission_wait_ms": float(prompt_result.profile.get("bert_admission_wait_ms", 0.0)),
"prompt_text_bert_queue_wait_ms": float(prompt_result.profile.get("bert_queue_wait_ms", 0.0)),
"prompt_text_bert_batch_collect_wait_ms": float(prompt_result.profile.get("bert_batch_collect_wait_ms", 0.0)),
"prompt_text_bert_forward_ms": float(prompt_result.profile.get("bert_forward_ms", 0.0)),
"prompt_text_bert_tokenize_ms": float(prompt_result.profile.get("bert_tokenize_ms", 0.0)),
"prompt_text_bert_scatter_ms": float(prompt_result.profile.get("bert_scatter_ms", 0.0)),
"prompt_text_bert_calls": float(prompt_result.profile.get("bert_calls", 0.0)),
"prompt_text_bert_stage_slots": float(prompt_result.profile.get("bert_stage_slots", 0.0)),
"prompt_text_bert_stage_inflight_peak": float(prompt_result.profile.get("bert_stage_inflight_peak", 0.0)),
"prompt_text_bert_batch_size_peak": float(prompt_result.profile.get("bert_batch_size_peak", 0.0)),
"prompt_text_bert_batch_tokens_peak": float(prompt_result.profile.get("bert_batch_tokens_peak", 0.0)),
"prompt_text_bert_pending_depth_on_enqueue_peak": float(
prompt_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0)
),
"prompt_text_bert_pending_depth_on_collect_peak": float(
prompt_result.profile.get("bert_pending_depth_on_collect_peak", 0.0)
),
"prompt_text_bert_high_pressure_mode_peak": float(
prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0)
),
"prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)),
"prompt_text_parallel_future_wait_ms": 0.0,
"prompt_text_parallel_future_executor_queue_ms": 0.0,
"prompt_text_parallel_future_run_ms": float(prompt_result.total_ms),
"prompt_text_parallel_future_finish_after_submit_ms": float(prompt_result.total_ms),
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
"text_bert_wait_ms": float(target_result.profile.get("bert_wait_ms", 0.0)),
"text_bert_admission_wait_ms": float(target_result.profile.get("bert_admission_wait_ms", 0.0)),
"text_bert_queue_wait_ms": float(target_result.profile.get("bert_queue_wait_ms", 0.0)),
"text_bert_batch_collect_wait_ms": float(target_result.profile.get("bert_batch_collect_wait_ms", 0.0)),
"text_bert_forward_ms": float(target_result.profile.get("bert_forward_ms", 0.0)),
"text_bert_tokenize_ms": float(target_result.profile.get("bert_tokenize_ms", 0.0)),
"text_bert_scatter_ms": float(target_result.profile.get("bert_scatter_ms", 0.0)),
"text_bert_calls": float(target_result.profile.get("bert_calls", 0.0)),
"text_bert_stage_slots": float(target_result.profile.get("bert_stage_slots", 0.0)),
"text_bert_stage_inflight_peak": float(target_result.profile.get("bert_stage_inflight_peak", 0.0)),
"text_bert_batch_size_peak": float(target_result.profile.get("bert_batch_size_peak", 0.0)),
"text_bert_batch_tokens_peak": float(target_result.profile.get("bert_batch_tokens_peak", 0.0)),
"text_bert_pending_depth_on_enqueue_peak": float(
target_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0)
),
"text_bert_pending_depth_on_collect_peak": float(
target_result.profile.get("bert_pending_depth_on_collect_peak", 0.0)
),
"text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)),
"text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)),
"text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)),
"text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
"audio_load_ms": audio_load_ms,
"audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)),
@ -233,6 +263,8 @@ def prepare_request_state(
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
"wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0,
}
if profile_overrides:
prepare_profile.update({key: float(value) for key, value in profile_overrides.items()})
return T2SRequestState(
request_id=spec.request_id,
ref_audio_path=spec.ref_audio_path,
@ -240,8 +272,8 @@ def prepare_request_state(
prompt_lang=spec.prompt_lang,
text=text,
text_lang=spec.text_lang,
norm_prompt_text=prompt_norm_text,
norm_text=norm_text,
norm_prompt_text=prompt_result.norm_text,
norm_text=target_result.norm_text,
phones=phones_tensor,
prompt_phones=prompt_phones_tensor,
all_phones=all_phones,
@ -260,6 +292,33 @@ def prepare_request_state(
)
@torch.inference_mode()
def prepare_request_state(
    tts: Any,
    spec: SchedulerRequestSpec,
) -> T2SRequestState:
    """Synchronously prepare one request: text features, reference audio, then state assembly.

    The prompt text is normalized with ``normalize_sentence`` and the target
    text only has surrounding newlines stripped. Prompt features are built
    before target features, and the reference-audio bundle is extracted last.

    Raises:
        ValueError: if the target text features carry no phones.
    """
    prepare_start = time.perf_counter()
    prepare_sync_start = time.perf_counter()

    prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
    text = spec.text.strip("\n")

    prompt_features = prepare_text_features(tts, prompt_text, spec.prompt_lang)
    target_features = prepare_text_features(tts, text, spec.text_lang)
    if target_features.phones is None:
        raise ValueError(f"{spec.request_id} text preprocessing returned no phones")

    audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))

    assembly_kwargs = dict(
        tts=tts,
        spec=spec,
        prompt_text=prompt_text,
        text=text,
        prompt_result=prompt_features,
        target_result=target_features,
        ref_audio_bundle=audio_bundle,
        prepare_start=prepare_start,
        prepare_sync_start=prepare_sync_start,
    )
    return build_request_state_from_parts(**assembly_kwargs)
def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor:
if hidden.shape[0] >= target_len:
return hidden
@ -417,7 +476,8 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct
decode_attn_mask=None,
k_cache=None,
v_cache=None,
step_idx=0,
kv_lens=None,
step_indices=torch.zeros((len(states),), dtype=torch.long, device=device),
prefill_done=False,
)
@ -433,6 +493,64 @@ def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> to
)
def _compact_cache_to_kv_lens(
cache: torch.Tensor,
kv_lens: torch.LongTensor,
) -> torch.Tensor:
target_len = int(kv_lens.max().item())
if cache.shape[1] == target_len and torch.all(kv_lens == target_len).item():
return cache
compacted = cache.new_zeros((cache.shape[0], target_len, cache.shape[2]))
for batch_index, kv_len in enumerate(kv_lens.tolist()):
if kv_len <= 0:
continue
compacted[batch_index, -kv_len:, :] = cache[batch_index, -kv_len:, :]
return compacted
def _compact_decode_mask_to_kv_lens(
decode_attn_mask: Optional[torch.Tensor],
kv_lens: torch.LongTensor,
) -> Optional[torch.Tensor]:
target_len = int(kv_lens.max().item()) + 1
if decode_attn_mask is None:
return None
if decode_attn_mask.shape[-1] == target_len and torch.all(kv_lens + 1 == target_len).item():
return decode_attn_mask
compacted = torch.ones(
(decode_attn_mask.shape[0], 1, 1, target_len),
dtype=decode_attn_mask.dtype,
device=decode_attn_mask.device,
)
for batch_index, kv_len in enumerate(kv_lens.tolist()):
current_len = kv_len + 1
compacted[batch_index, :, :, -current_len:] = decode_attn_mask[batch_index, :, :, -current_len:]
if not compacted.any().item():
return None
return compacted
def _advance_decode_mask(
decode_attn_mask: Optional[torch.Tensor],
kv_lens: torch.LongTensor,
) -> Optional[torch.Tensor]:
if decode_attn_mask is None:
return None
target_len = int(kv_lens.max().item()) + 2
advanced = torch.zeros(
(decode_attn_mask.shape[0], 1, 1, target_len),
dtype=decode_attn_mask.dtype,
device=decode_attn_mask.device,
)
for batch_index, kv_len in enumerate(kv_lens.tolist()):
current_len = kv_len + 1
next_mask = F.pad(decode_attn_mask[batch_index : batch_index + 1, :, :, -current_len:], (0, 1), value=False)
advanced[batch_index : batch_index + 1, :, :, -next_mask.shape[-1] :] = next_mask
if not advanced.any().item():
return None
return advanced
def _sample_per_request(
model: Any,
active_batch: T2SActiveBatch,
@ -443,16 +561,15 @@ def _sample_per_request(
keep_indices: List[int] = []
updated_sequences: List[torch.LongTensor] = []
step_idx = active_batch.step_idx
sampling_keys = [
_sampling_group_key(
top_k=state.top_k,
top_p=state.top_p,
temperature=state.temperature,
repetition_penalty=state.repetition_penalty,
trim_eos=False,
trim_eos=int(active_batch.step_indices[batch_index].item()) < 11,
)
for state in active_batch.states
for batch_index, state in enumerate(active_batch.states)
]
sampled_items, argmax_tokens = _batched_sample_by_group(
logits=logits,
@ -460,6 +577,7 @@ def _sample_per_request(
sampling_keys=sampling_keys,
)
for batch_index, state in enumerate(active_batch.states):
step_index = int(active_batch.step_indices[batch_index].item())
current_history = active_batch.y_sequences[batch_index]
sampled = sampled_items[batch_index]
sampled_token = int(sampled[0, 0].item())
@ -469,7 +587,7 @@ def _sample_per_request(
finish_reason: Optional[str] = None
if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num:
finish_reason = "early_stop"
elif step_idx + 1 >= max_steps:
elif step_index + 1 >= max_steps:
finish_reason = "max_step"
elif sampled_token == model.EOS:
finish_reason = "eos_sample"
@ -482,7 +600,7 @@ def _sample_per_request(
T2SFinishedItem(
request_id=state.request_id,
semantic_tokens=new_history[prefix_len:-1].clone(),
finish_idx=step_idx,
finish_idx=step_index,
finish_reason=finish_reason,
)
)
@ -493,30 +611,48 @@ def _sample_per_request(
return finished_items, keep_indices, updated_sequences
@torch.inference_mode()
def decode_one_step(
model: Any,
active_batch: T2SActiveBatch,
max_steps: int,
) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]:
if not active_batch.prefill_done:
was_prefill = not active_batch.prefill_done
if was_prefill:
if active_batch.prefill_attn_mask is None or active_batch.key_padding_mask is None:
raise ValueError("prefill 阶段缺少必要 mask")
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt(
active_batch.xy_pos, active_batch.prefill_attn_mask, None
)
active_batch.kv_lens = active_batch.x_lens + active_batch.prefix_lens
active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None:
raise ValueError("prefill 阶段未生成完整 KV cache")
active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache]
active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache]
active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(active_batch.decode_attn_mask, active_batch.kv_lens)
active_batch.x = None
active_batch.x_lens = None
active_batch.key_padding_mask = None
active_batch.prefill_attn_mask = None
active_batch.prefill_done = True
else:
if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None:
raise ValueError("decode 阶段缺少 KV cache")
batched_decode_attn_mask = None
if active_batch.decode_attn_mask is not None:
batched_decode_attn_mask = _materialize_decode_mask_for_active_batch(active_batch)
if not batched_decode_attn_mask.any().item():
batched_decode_attn_mask = None
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token(
active_batch.xy_pos,
active_batch.k_cache,
active_batch.v_cache,
active_batch.decode_attn_mask,
batched_decode_attn_mask,
)
if active_batch.decode_attn_mask is not None:
active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False)
active_batch.decode_attn_mask = _advance_decode_mask(active_batch.decode_attn_mask, active_batch.kv_lens)
logits = model.ar_predict_layer(xy_dec[:, -1])
if active_batch.step_idx < 11:
logits = logits[:, :-1]
finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps)
if len(keep_indices) == 0:
@ -528,16 +664,32 @@ def decode_one_step(
active_batch.states = [active_batch.states[i] for i in keep_indices]
active_batch.y_sequences = updated_sequences
active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor)
next_step_indices = torch.index_select(active_batch.step_indices, dim=0, index=keep_tensor)
next_kv_lens = None if active_batch.kv_lens is None else torch.index_select(active_batch.kv_lens, dim=0, index=keep_tensor)
active_batch.step_indices = next_step_indices + 1
if not was_prefill:
if next_kv_lens is not None:
active_batch.kv_lens = next_kv_lens + 1
else:
active_batch.kv_lens = next_kv_lens
if active_batch.decode_attn_mask is not None:
active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor)
if not active_batch.decode_attn_mask.any().item():
active_batch.decode_attn_mask = None
if active_batch.k_cache is not None and active_batch.v_cache is not None:
for cache_index in range(len(active_batch.k_cache)):
active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor)
active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor)
if active_batch.kv_lens is not None:
active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache]
active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache]
active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(
active_batch.decode_attn_mask,
active_batch.kv_lens,
)
active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences)
active_batch.step_idx += 1
return active_batch, finished_items
@ -583,6 +735,126 @@ def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) ->
)
def _materialize_decode_mask_for_active_batch(
    active_batch: T2SActiveBatch,
    target_mask_len: Optional[int] = None,
) -> torch.Tensor:
    """Build a dense boolean decode mask covering every request in the batch.

    The mask length defaults to the KV-cache length of the first layer plus
    one. When the batch stores no mask, an all-False mask of shape
    ``(num_requests, 1, 1, current_len)`` is created; otherwise each stored
    row is fitted to its own ``kv_len + 1`` and then left-padded to the
    requested length.

    Args:
        active_batch: Batch whose ``k_cache`` and ``kv_lens`` must be set.
        target_mask_len: Desired last-dimension length; ``None`` means the
            current cache length plus one.

    Returns:
        The materialized mask tensor.

    Raises:
        ValueError: If the batch has no KV cache or no ``kv_lens``.
    """
    layer_caches = active_batch.k_cache
    if layer_caches is None or active_batch.kv_lens is None:
        raise ValueError("active batch 缺少 KV cache 或 kv_lens")

    first_layer = layer_caches[0]
    current_mask_len = first_layer.shape[1] + 1
    wanted_len = current_mask_len if target_mask_len is None else target_mask_len
    stored_mask = active_batch.decode_attn_mask

    if stored_mask is None:
        # Nothing is masked yet: start from an all-False mask and only pad
        # when the caller asked for a different length.
        blank_mask = torch.zeros(
            (len(active_batch.request_ids), 1, 1, current_mask_len),
            dtype=torch.bool,
            device=first_layer.device,
        )
        if wanted_len != current_mask_len:
            blank_mask = _pad_decode_mask_left(blank_mask, wanted_len)
        return blank_mask

    padded_rows: List[torch.Tensor] = [
        _pad_decode_mask_left(
            _fit_decode_mask_length(stored_mask[row_index : row_index + 1], kv_len + 1),
            wanted_len,
        )
        for row_index, kv_len in enumerate(active_batch.kv_lens.tolist())
    ]
    return torch.cat(padded_rows, dim=0)
@torch.inference_mode()
def run_prefill_active_batch(
    model: Any,
    states: Sequence[T2SRequestState],
    max_steps: int,
) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]:
    """Prefill newly admitted requests and run their first decode step.

    Args:
        model: T2S model forwarded to the batch builder and the step function.
        states: Request states to admit; may be empty.
        max_steps: Step limit forwarded to ``decode_one_step``.

    Returns:
        ``(None, [])`` when ``states`` is empty, otherwise whatever
        ``decode_one_step`` returns for the freshly built prefill batch.
    """
    if states:
        prefill_batch = build_prefill_batch(model, states)
        return decode_one_step(model, prefill_batch, max_steps=max_steps)
    return None, []
@torch.inference_mode()
def merge_active_batches(
    model: Any,
    left_batch: Optional[T2SActiveBatch],
    right_batch: Optional[T2SActiveBatch],
) -> Optional[T2SActiveBatch]:
    """Concatenate two post-prefill active batches along the batch dimension.

    Every KV-cache layer of both batches is left-padded to the longer cache
    length, and both decode masks are materialized at that length plus one,
    before being concatenated and compacted against the merged ``kv_lens``.

    Args:
        model: T2S model, used to rebuild ``xy_pos`` for the merged sequences.
        left_batch: Batch placed first in the merged order, or ``None``.
        right_batch: Batch placed second in the merged order, or ``None``.

    Returns:
        The other batch unchanged when one side is ``None`` (``None`` when both
        are), otherwise a new ``T2SActiveBatch`` holding both batches.

    Raises:
        ValueError: If either batch has not finished prefill, or lacks its KV
            cache or ``kv_lens``.
    """
    if left_batch is None:
        return right_batch
    if right_batch is None:
        return left_batch
    if not left_batch.prefill_done or not right_batch.prefill_done:
        raise ValueError("只有 prefill 完成后的 active batch 才能 merge")
    if left_batch.k_cache is None or left_batch.v_cache is None or right_batch.k_cache is None or right_batch.v_cache is None:
        raise ValueError("merge active batch 时缺少 KV cache")
    # Validate kv_lens before any padding work. Checking it only after the
    # masks were materialized made this branch unreachable, because the
    # materialization helper raises first when kv_lens is missing.
    if left_batch.kv_lens is None or right_batch.kv_lens is None:
        raise ValueError("merge active batch 时缺少 kv_lens")

    left_kv_len = int(left_batch.k_cache[0].shape[1])
    right_kv_len = int(right_batch.k_cache[0].shape[1])
    merged_kv_len = max(left_kv_len, right_kv_len)
    merged_mask_len = merged_kv_len + 1

    def _concat_padded(left_layer: torch.Tensor, right_layer: torch.Tensor) -> torch.Tensor:
        # Bring both sides to the common cache length, then stack on batch dim.
        return torch.cat(
            [
                _pad_cache_left(left_layer, merged_kv_len),
                _pad_cache_left(right_layer, merged_kv_len),
            ],
            dim=0,
        )

    layer_count = len(left_batch.k_cache)
    merged_k_cache: List[torch.Tensor] = [
        _concat_padded(left_batch.k_cache[layer_index], right_batch.k_cache[layer_index])
        for layer_index in range(layer_count)
    ]
    merged_v_cache: List[torch.Tensor] = [
        _concat_padded(left_batch.v_cache[layer_index], right_batch.v_cache[layer_index])
        for layer_index in range(layer_count)
    ]

    merged_kv_lens = torch.cat([left_batch.kv_lens, right_batch.kv_lens], dim=0)
    merged_decode_attn_mask = torch.cat(
        [
            _materialize_decode_mask_for_active_batch(left_batch, merged_mask_len),
            _materialize_decode_mask_for_active_batch(right_batch, merged_mask_len),
        ],
        dim=0,
    )
    merged_decode_attn_mask = _compact_decode_mask_to_kv_lens(merged_decode_attn_mask, merged_kv_lens)

    merged_y_sequences = list(left_batch.y_sequences) + list(right_batch.y_sequences)
    # NOTE(review): step_idx is not passed below, so the merged batch presumably
    # starts from the dataclass default while decode_one_step branches on
    # active_batch.step_idx — confirm this reset is intended.
    return T2SActiveBatch(
        request_ids=list(left_batch.request_ids) + list(right_batch.request_ids),
        states=list(left_batch.states) + list(right_batch.states),
        x=None,
        x_lens=None,
        y_sequences=merged_y_sequences,
        prefix_lens=torch.cat([left_batch.prefix_lens, right_batch.prefix_lens], dim=0),
        xy_pos=build_next_xy_pos(model, merged_y_sequences),
        key_padding_mask=None,
        prefill_attn_mask=None,
        decode_attn_mask=merged_decode_attn_mask,
        k_cache=merged_k_cache,
        v_cache=merged_v_cache,
        kv_lens=merged_kv_lens,
        step_indices=torch.cat([left_batch.step_indices, right_batch.step_indices], dim=0),
        prefill_done=True,
    )
@torch.inference_mode()
def run_prefill_step(
model: Any,
@ -804,29 +1076,24 @@ def run_scheduler_continuous(
max_steps: int,
) -> List[T2SFinishedItem]:
pending = sorted(states, key=lambda item: (item.ready_step, item.request_id))
running_requests: List[T2SRunningRequest] = []
active_batch: Optional[T2SActiveBatch] = None
finished: List[T2SFinishedItem] = []
current_tick = 0
while pending or running_requests:
while pending or active_batch is not None:
admitted: List[T2SRequestState] = []
while pending and pending[0].ready_step <= current_tick:
admitted.append(pending.pop(0))
admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps)
admitted_active_batch, admitted_finished = run_prefill_active_batch(model, admitted, max_steps=max_steps)
finished.extend(admitted_finished)
active_batch = merge_active_batches(model, active_batch, admitted_active_batch)
if running_requests:
running_requests, step_finished = run_decode_step_for_running(
model,
running_requests,
max_steps=max_steps,
)
if active_batch is not None:
active_batch, step_finished = decode_one_step(model, active_batch, max_steps=max_steps)
finished.extend(step_finished)
running_requests.extend(admitted_running)
if not running_requests and pending:
if active_batch is None and pending:
current_tick = max(current_tick + 1, pending[0].ready_step)
continue

View File

@ -0,0 +1,100 @@
import os
import re
import sys
from typing import Dict, List, Optional, Tuple
now_dir = os.getcwd()
sys.path.append(now_dir)
from text.LangSegmenter import LangSegmenter
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
PreparedTextSegmentPayload = Dict[str, object]
def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]:
    """Split ``text`` into per-language segments.

    Args:
        text: Input text.
        language: Language mode. ``"en"`` returns the text as one segment;
            ``"all_zh"``/``"all_yue"``/``"all_ja"``/``"all_ko"`` segment with a
            language hint; ``"auto"``/``"auto_yue"`` segment without a hint.
            The ``*_yue`` modes relabel ``"zh"`` segments as ``"yue"``. Any
            other value merges adjacent segments into English / non-English
            runs and labels the non-English runs with ``language``.

    Returns:
        ``(texts, languages)``: two parallel lists, one entry per segment.
    """
    if language == "en":
        return [text], ["en"]

    hinted_modes = {"all_zh": "zh", "all_yue": "zh", "all_ja": "ja", "all_ko": "ko"}
    unhinted_modes = ("auto", "auto_yue")
    pieces: List[str] = []
    tags: List[str] = []

    if language in hinted_modes or language in unhinted_modes:
        if language in hinted_modes:
            segments = LangSegmenter.getTexts(text, hinted_modes[language])
        else:
            segments = LangSegmenter.getTexts(text)
        relabel_zh = language in ("all_yue", "auto_yue")
        for segment in segments:
            if relabel_zh and segment["lang"] == "zh":
                segment["lang"] = "yue"
            tags.append(segment["lang"])
            pieces.append(segment["text"])
        return pieces, tags

    # Explicit single-language mode: keep English apart and fold everything
    # else into runs tagged with the requested language.
    for segment in LangSegmenter.getTexts(text):
        is_english = segment["lang"] == "en"
        if tags and is_english == (tags[-1] == "en"):
            pieces[-1] += segment["text"]
            continue
        tags.append(segment["lang"] if is_english else language)
        pieces.append(segment["text"])
    return pieces, tags
def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]:
    """Clean one text segment and convert its phones to a sequence.

    Args:
        text: Segment text.
        language: Segment language; every ``"all_"`` occurrence is removed
            before it is passed to ``clean_text``.
        version: Version string forwarded to ``clean_text`` and
            ``cleaned_text_to_sequence``.

    Returns:
        ``(phones, word2ph, norm_text)`` where ``phones`` is the list returned
        by ``cleaned_text_to_sequence``, ``word2ph`` is a list copy or ``None``,
        and ``norm_text`` is coerced to ``str``.
    """
    base_language = language.replace("all_", "")
    phone_symbols, word2ph, norm_text = clean_text(text, base_language, version)
    phone_ids = list(cleaned_text_to_sequence(phone_symbols, version))
    if word2ph is not None:
        word2ph = list(word2ph)
    return phone_ids, word2ph, str(norm_text)
def preprocess_text_segments_payload(
    text: str,
    language: str,
    version: str,
    final: bool = False,
) -> List[PreparedTextSegmentPayload]:
    """Segment and clean ``text`` into per-language payload dicts.

    Runs of two or more spaces are collapsed to one space first. If the
    segments yield fewer than six phones in total and this is not already the
    retry, the function retries once with ``"."`` prepended to the text.

    Args:
        text: Raw input text.
        language: Language mode passed to ``split_text_by_language``.
        version: Version string passed to ``clean_text_segment``.
        final: ``True`` on the retry call; disables a further retry.

    Returns:
        One dict per segment with the keys ``language``, ``phones``,
        ``word2ph`` and ``norm_text``.
    """
    collapsed = re.sub(r" {2,}", " ", text)
    segment_texts, segment_langs = split_text_by_language(collapsed, language)

    payloads: List[PreparedTextSegmentPayload] = []
    phone_counts: List[int] = []
    for segment_text, segment_lang in zip(segment_texts, segment_langs):
        phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
        payloads.append(
            {
                "language": segment_lang.replace("all_", ""),
                "phones": phones,
                "word2ph": word2ph,
                "norm_text": norm_text,
            }
        )
        phone_counts.append(len(phones))

    if final or sum(phone_counts) >= 6:
        return payloads
    # Too few phones: retry exactly once with a leading period.
    return preprocess_text_segments_payload("." + collapsed, language, version, final=True)

View File

@ -2,6 +2,7 @@ import warnings
warnings.filterwarnings("ignore")
import math
from typing import List
import torch
from torch import nn
@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode_batched_request_local(
    self,
    codes: torch.Tensor,
    code_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    refer_list: List[torch.Tensor],
    noise_scale: float = 0.5,
    speed: float = 1,
    sv_emb: torch.Tensor | None = None,
):
    """Decode a batch of requests, each with its own reference spectrogram.

    The per-request references are zero-padded on the last dimension into one
    batch tensor, encoded into ``ge``, and then used to condition the text
    encoder, the reversed flow and the decoder.

    Args:
        codes: Semantic codes; the batch size is read from ``codes.size(1)``.
        code_lengths: Per-request code lengths (doubled to get ``y_lengths``).
        text: Text/phone tensor passed to ``enc_p``.
        text_lengths: Per-request text lengths.
        refer_list: One reference tensor per request; presumably shaped
            ``(1, channels, frames)`` since ``squeeze(0)`` is applied — TODO confirm.
        noise_scale: Scale applied to the sampled prior noise.
        speed: Speed factor forwarded to ``enc_p``.
        sv_emb: Speaker-verification embedding; required when ``is_v2pro``.

    Returns:
        ``(audio, audio_lengths)`` where ``audio_lengths`` is the per-request
        count of valid ``y_mask`` frames times the product of the decoder's
        upsampling strides.

    Raises:
        ValueError: On an empty batch, a ``refer_list`` whose length differs
            from the batch size, or a missing ``sv_emb`` for v2Pro models.
    """
    batch_size = int(codes.size(1))
    if batch_size <= 0:
        raise ValueError("decode_batched_request_local 收到空 batch")
    if len(refer_list) != batch_size:
        raise ValueError("refer_list 数量与 batch size 不一致")

    device = codes.device
    first_refer = refer_list[0]
    frame_counts = [int(refer.size(2)) for refer in refer_list]
    refer_lengths = torch.LongTensor(frame_counts).to(device)
    max_refer_len = int(refer_lengths.max().item())

    # Zero-padded stack of all references; the mask below marks valid frames.
    refer_batch = torch.zeros(
        (batch_size, int(first_refer.size(1)), max_refer_len),
        dtype=first_refer.dtype,
        device=device,
    )
    for row, (refer, frames) in enumerate(zip(refer_list, frame_counts)):
        refer_batch[row, :, :frames] = refer.squeeze(0)
    refer_mask = commons.sequence_mask(refer_lengths, max_refer_len).unsqueeze(1).to(refer_batch.dtype)

    # Non-v1 models only feed the first 704 channels to the reference encoder.
    encoder_input = refer_batch if self.version == "v1" else refer_batch[:, :704]
    ge = self.ref_enc(encoder_input * refer_mask, refer_mask)
    if self.is_v2pro:
        if sv_emb is None:
            raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb")
        ge = ge + self.sv_emb(sv_emb).unsqueeze(-1)
        ge = self.prelu(ge)

    quantized = self.quantizer.decode(codes)
    if self.semantic_frame_rate == "25hz":
        quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")
    y_lengths = code_lengths.to(device=device, dtype=torch.long) * 2
    text_lengths = text_lengths.to(device=text.device, dtype=torch.long)

    text_condition = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge
    x, m_p, logs_p, y_mask, _, _ = self.enc_p(
        quantized,
        y_lengths,
        text,
        text_lengths,
        text_condition,
        speed,
    )
    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
    z = self.flow(z_p, y_mask, g=ge, reverse=True)
    audio = self.dec(z * y_mask, g=ge)

    # Samples per latent frame = product of the decoder's upsampling strides.
    upsample_factor = 1
    for up_layer in self.dec.ups:
        layer_stride = up_layer.stride
        upsample_factor *= int(layer_stride[0] if isinstance(layer_stride, tuple) else layer_stride)
    valid_frames = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long)
    audio_lengths = valid_frames * int(upsample_factor)
    return audio, audio_lengths
@torch.no_grad()
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):

565
api_v3.py
View File

@ -107,6 +107,7 @@ import sys
import time
import traceback
import uuid
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Union
@ -115,6 +116,10 @@ now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
from runtime_preload import preload_text_runtime_deps
preload_text_runtime_deps()
import argparse
import subprocess
import wave
@ -128,14 +133,15 @@ import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
SchedulerRequestSpec,
T2SActiveBatch,
T2SFinishedItem,
T2SRunningRequest,
T2SRequestState,
prepare_request_state,
run_decode_step_for_running,
run_prefill_step,
merge_active_batches,
decode_one_step,
run_prefill_active_batch,
run_scheduler_continuous,
)
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
@ -238,39 +244,71 @@ class SchedulerPendingJob:
request_id: str
state: T2SRequestState
done_event: threading.Event
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
enqueue_time: float
speed_factor: float
sample_steps: int
media_type: str
prepare_ms: float = 0.0
prepare_wall_ms: float = 0.0
prepare_profile_total_ms: float = 0.0
first_schedule_time: float | None = None
prefill_ms: float = 0.0
merge_ms: float = 0.0
decode_ms: float = 0.0
finalize_wait_ms: float = 0.0
synth_ms: float = 0.0
pack_ms: float = 0.0
decode_steps: int = 0
result_ready_time: float | None = None
result: dict | None = None
sample_rate: int | None = None
audio_data: np.ndarray | None = None
error: str | None = None
@dataclass
class SchedulerFinalizeTask:
    # Queue entry for one finished T2S item that is waiting in
    # SchedulerDebugWorker.finalize_pending_tasks for a finalize worker.
    request_id: str  # Copied from item.request_id when the task is enqueued.
    item: T2SFinishedItem  # The finished item whose audio still has to be synthesized.
    enqueued_time: float  # time.perf_counter() at enqueue; basis for finalize_wait_ms.
class SchedulerDebugWorker:
def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5):
self.tts = tts
self.max_steps = max_steps
self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0
self.prepare_coordinator = PrepareCoordinator(tts)
self.condition = threading.Condition()
self.prepare_inflight = 0
self.prepare_peak_inflight = 0
self.finalize_condition = threading.Condition()
self.finalize_pending_tasks: deque[SchedulerFinalizeTask] = deque()
self.finalize_pending_peak = 0
self.finalize_inflight = 0
self.finalize_inflight_peak = 0
self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
self.pending_jobs: List[SchedulerPendingJob] = []
self.running_requests: List[T2SRunningRequest] = []
self.active_batch: T2SActiveBatch | None = None
self.job_map: dict[str, SchedulerPendingJob] = {}
self.total_finished = 0
self.total_submitted = 0
self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True)
self.worker_thread.start()
self.finalize_threads = [
threading.Thread(
target=self._run_finalize_loop,
name=f"t2s-scheduler-finalize-{worker_index}",
daemon=True,
)
for worker_index in range(self.finalize_workers)
]
for finalize_thread in self.finalize_threads:
finalize_thread.start()
def _sync_device(self) -> None:
try:
@ -283,20 +321,7 @@ class SchedulerDebugWorker:
pass
def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState:
with self.condition:
self.prepare_inflight += 1
prepare_inflight_on_enter = self.prepare_inflight
if self.prepare_inflight > self.prepare_peak_inflight:
self.prepare_peak_inflight = self.prepare_inflight
prepare_peak_inflight = self.prepare_peak_inflight
try:
state = prepare_request_state(self.tts, spec)
state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter)
state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight)
return state
finally:
with self.condition:
self.prepare_inflight = max(0, self.prepare_inflight - 1)
raise RuntimeError("prepare_state sync path has been replaced by PrepareCoordinator")
def submit(
self,
@ -304,27 +329,47 @@ class SchedulerDebugWorker:
speed_factor: float,
sample_steps: int,
media_type: str,
prepare_ms: float,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None,
done_future: asyncio.Future | None = None,
) -> SchedulerPendingJob:
job = SchedulerPendingJob(
request_id=state.request_id,
state=state,
done_event=threading.Event(),
done_loop=done_loop,
done_future=done_future,
enqueue_time=time.perf_counter(),
speed_factor=float(speed_factor),
sample_steps=int(sample_steps),
media_type=media_type,
prepare_ms=float(prepare_ms),
prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms),
)
with self.condition:
self.pending_jobs.append(job)
self.job_map[job.request_id] = job
self.total_submitted += 1
self.condition.notify_all()
with self.finalize_condition:
self.finalize_condition.notify_all()
return job
async def prepare_state_async(self, spec: SchedulerRequestSpec) -> T2SRequestState:
state, _, _ = await self.prepare_coordinator.prepare_state_profiled_async(spec, time.perf_counter())
return state
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
return await asyncio.gather(*[self.prepare_state_async(spec) for spec in specs])
async def prepare_state_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> tuple[T2SRequestState, float, float]:
return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None:
with self.condition:
for job in jobs:
@ -340,6 +385,14 @@ class SchedulerDebugWorker:
if tracked_job is not None:
tracked_job.prefill_ms += elapsed_ms
def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
elapsed_ms = elapsed_s * 1000.0
with self.condition:
for request_id in request_ids:
job = self.job_map.get(request_id)
if job is not None:
job.merge_ms += elapsed_ms
def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
elapsed_ms = elapsed_s * 1000.0
with self.condition:
@ -349,16 +402,30 @@ class SchedulerDebugWorker:
job.decode_ms += elapsed_ms
job.decode_steps += 1
def _add_finalize_wait_ms(self, request_ids: List[str], elapsed_ms: float) -> None:
with self.condition:
for request_id in request_ids:
job = self.job_map.get(request_id)
if job is not None:
job.finalize_wait_ms += elapsed_ms
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device)
phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device)
semantic_tokens = item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0).to(self.tts.configs.device)
phones = job.state.phones.detach().clone().unsqueeze(0).to(self.tts.configs.device)
prompt_semantic = job.state.prompt_semantic.detach().clone()
prompt_phones = job.state.prompt_phones.detach().clone()
refer_spec = (
job.state.refer_spec[0].detach().clone(),
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
)
raw_audio = job.state.raw_audio.detach().clone()
audio_fragment = self.tts.synthesize_audio_request_local(
semantic_tokens=semantic_tokens,
phones=phones,
prompt_semantic=job.state.prompt_semantic,
prompt_phones=job.state.prompt_phones,
refer_spec=job.state.refer_spec,
raw_audio=job.state.raw_audio,
prompt_semantic=prompt_semantic,
prompt_phones=prompt_phones,
refer_spec=refer_spec,
raw_audio=raw_audio,
raw_sr=job.state.raw_sr,
speed=float(job.speed_factor),
sample_steps=int(job.sample_steps),
@ -375,6 +442,11 @@ class SchedulerDebugWorker:
)
def get_state(self) -> dict:
with self.finalize_condition:
finalize_pending = len(self.finalize_pending_tasks)
finalize_pending_peak = self.finalize_pending_peak
finalize_inflight = self.finalize_inflight
finalize_inflight_peak = self.finalize_inflight_peak
with self.condition:
bert_stage = self.tts.prepare_bert_stage_limiter.snapshot()
ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot()
@ -388,12 +460,24 @@ class SchedulerDebugWorker:
if self.tts.prepare_ref_semantic_batch_worker is None
else self.tts.prepare_ref_semantic_batch_worker.snapshot()
)
prepare_coordinator_state = self.prepare_coordinator.snapshot()
return {
"pending_jobs": len(self.pending_jobs),
"running_requests": len(self.running_requests),
"prepare_inflight": self.prepare_inflight,
"prepare_peak_inflight": self.prepare_peak_inflight,
"running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids),
"prepare_inflight": prepare_coordinator_state["inflight"],
"prepare_peak_inflight": prepare_coordinator_state["peak_inflight"],
"finalize_pending": finalize_pending,
"finalize_pending_peak": finalize_pending_peak,
"finalize_inflight": finalize_inflight,
"finalize_inflight_peak": finalize_inflight_peak,
"finalize_workers": self.finalize_workers,
"finalize_mode": self.finalize_mode,
"finalize_batch_max_items": self.finalize_batch_max_items,
"finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0,
"prepare_request_executor_workers": 0,
"prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)),
"prepare_text_feature_workers": int(prepare_coordinator_state["text_feature_workers"]),
"prepare_ref_audio_workers": int(prepare_coordinator_state["ref_audio_workers"]),
"prepare_bert_stage": bert_stage,
"prepare_bert_batch_worker": bert_batch_worker,
"prepare_ref_audio_stage": ref_audio_stage,
@ -405,59 +489,217 @@ class SchedulerDebugWorker:
"micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000),
}
def _finalize_finished(self, items: List[T2SFinishedItem]) -> None:
def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None:
if not items:
return
jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
tasks: List[SchedulerFinalizeTask] = []
enqueued_time = time.perf_counter()
with self.condition:
for item in items:
job = self.job_map.get(item.request_id)
if job is not None:
jobs_to_finalize.append((job, item))
tasks.append(
SchedulerFinalizeTask(
request_id=item.request_id,
item=item,
enqueued_time=enqueued_time,
)
)
if not tasks:
return
with self.finalize_condition:
self.finalize_pending_tasks.extend(tasks)
if len(self.finalize_pending_tasks) > self.finalize_pending_peak:
self.finalize_pending_peak = len(self.finalize_pending_tasks)
self.finalize_condition.notify_all()
for job, item in jobs_to_finalize:
@staticmethod
def _finalize_batch_key(job: SchedulerPendingJob) -> tuple[float, int]:
return (round(float(job.speed_factor), 6), int(job.sample_steps))
def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]:
with self.finalize_condition:
while not self.finalize_pending_tasks:
self.finalize_condition.wait()
if self.finalize_mode == "after_t2s_drain":
while not self._is_t2s_drained():
self.finalize_condition.wait(timeout=self.micro_batch_wait_s)
task = self.finalize_pending_tasks.popleft()
selected_tasks = [task]
batch_key = None
with self.condition:
first_job = self.job_map.get(task.request_id)
if first_job is not None:
batch_key = self._finalize_batch_key(first_job)
batch_deadline = time.perf_counter() + self.finalize_batch_wait_s
while len(selected_tasks) < self.finalize_batch_max_items:
if batch_key is None:
break
matched_index = None
for pending_index, pending_task in enumerate(self.finalize_pending_tasks):
with self.condition:
pending_job = self.job_map.get(pending_task.request_id)
if pending_job is None:
matched_index = pending_index
break
if self._finalize_batch_key(pending_job) == batch_key:
matched_index = pending_index
break
if matched_index is not None:
selected_tasks.append(self.finalize_pending_tasks[matched_index])
del self.finalize_pending_tasks[matched_index]
continue
remaining = batch_deadline - time.perf_counter()
if remaining <= 0:
break
self.finalize_condition.wait(timeout=remaining)
self.finalize_inflight += len(selected_tasks)
if self.finalize_inflight > self.finalize_inflight_peak:
self.finalize_inflight_peak = self.finalize_inflight
return selected_tasks
def _finalize_task_done(self, count: int) -> None:
with self.finalize_condition:
self.finalize_inflight = max(0, self.finalize_inflight - count)
def _is_t2s_drained(self) -> bool:
with self.condition:
return (
self.active_batch is None
and not self.pending_jobs
and self.prepare_inflight <= 0
)
def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None:
finished_at = time.perf_counter()
with self.condition:
if self.job_map.get(item.request_id) is not job:
return
queue_wait_ms = 0.0
if job.first_schedule_time is not None:
queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
worker_residual_ms = max(
0.0,
worker_total_ms
- queue_wait_ms
- job.prefill_ms
- job.merge_ms
- job.decode_ms
- job.finalize_wait_ms
- job.synth_ms,
)
worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
job.sample_rate = int(sample_rate)
job.audio_data = audio_data
job.result_ready_time = finished_at
prepare_profile = dict(job.state.prepare_profile)
job.result = {
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"prepare_ms": job.prepare_wall_ms,
"prepare_wall_ms": job.prepare_wall_ms,
"prepare_profile_total_ms": job.prepare_profile_total_ms,
"prepare_profile": prepare_profile,
"queue_wait_ms": queue_wait_ms,
"prefill_ms": job.prefill_ms,
"merge_ms": job.merge_ms,
"decode_ms": job.decode_ms,
"finalize_wait_ms": job.finalize_wait_ms,
"synth_ms": job.synth_ms,
"worker_residual_ms": worker_residual_ms,
"worker_other_ms": worker_other_ms,
"worker_total_ms": worker_total_ms,
"decode_steps": int(job.decode_steps),
"sample_rate": int(sample_rate),
"media_type": job.media_type,
}
job.done_event.set()
self._notify_done_future(job)
self.job_map.pop(item.request_id, None)
self.total_finished += 1
def _synthesize_finished_audio_batch(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> List[tuple[int, np.ndarray]]:
semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items]
phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items]
refer_specs = []
speeds = []
sample_steps_list = []
for job, _ in jobs_and_items:
refer_specs.append(
(
job.state.refer_spec[0].detach().clone(),
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
)
)
speeds.append(float(job.speed_factor))
sample_steps_list.append(int(job.sample_steps))
audio_fragments = self.tts.synthesize_audio_requests_local_batched(
semantic_tokens_list=semantic_tokens_list,
phones_list=phones_list,
refer_specs=refer_specs,
speeds=speeds,
sample_steps_list=sample_steps_list,
)
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
results: List[tuple[int, np.ndarray]] = []
for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments):
results.append(
self.tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=float(job.speed_factor),
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
)
return results
def _run_finalize_loop(self) -> None:
    """Run forever as a finalize worker: take task batches and synthesize their audio.

    For each batch, the jobs still tracked in ``job_map`` are looked up, their
    queue wait is recorded, audio is synthesized, and every job is completed
    through ``_complete_finalize_task``. Any exception marks every request of
    the batch as failed through ``_finalize_error``. The in-flight counter is
    always released.
    """
    while True:
        tasks = self._take_finalize_task_batch()
        try:
            jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
            with self.condition:
                for task in tasks:
                    job = self.job_map.get(task.request_id)
                    if job is None:
                        continue
                    jobs_and_items.append((job, task.item))
            if not jobs_and_items:
                continue

            now = time.perf_counter()
            for task in tasks:
                self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0))

            self._sync_device()
            synth_start = time.perf_counter()
            if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder:
                # The vocoder path is not batched, so every job is synthesized
                # on its own. Synthesizing only the first pair would leave the
                # other requests of the batch unresolved.
                batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items]
            else:
                batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
            self._sync_device()
            synth_ms = (time.perf_counter() - synth_start) * 1000.0

            # Guard the zip() below against silently dropping jobs.
            if len(batch_results) != len(jobs_and_items):
                raise RuntimeError(
                    f"finalize produced {len(batch_results)} results for {len(jobs_and_items)} requests"
                )

            # The whole batch's synthesis time is attributed to each job in it.
            with self.condition:
                for job, _ in jobs_and_items:
                    tracked_job = self.job_map.get(job.request_id)
                    if tracked_job is not None:
                        tracked_job.synth_ms += synth_ms
            for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
                self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data)
        except Exception as exc:
            self._finalize_error([task.request_id for task in tasks], str(exc))
        finally:
            self._finalize_task_done(len(tasks))
def _finalize_error(self, request_ids: List[str], error: str) -> None:
if not request_ids:
@ -469,12 +711,28 @@ class SchedulerDebugWorker:
continue
job.error = error
job.done_event.set()
self._notify_done_future(job)
self.job_map.pop(request_id, None)
self.total_finished += 1
@staticmethod
def _resolve_done_future(job: SchedulerPendingJob) -> None:
future = job.done_future
if future is None or future.done():
return
future.set_result(True)
def _notify_done_future(self, job: SchedulerPendingJob) -> None:
if job.done_loop is None or job.done_future is None:
return
try:
job.done_loop.call_soon_threadsafe(self._resolve_done_future, job)
except RuntimeError:
pass
def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs and not self.running_requests:
if not self.pending_jobs and self.active_batch is None:
self.condition.wait(timeout=self.micro_batch_wait_s)
elif wait_for_batch and self.pending_jobs:
self.condition.wait(timeout=self.micro_batch_wait_s)
@ -482,11 +740,13 @@ class SchedulerDebugWorker:
return []
pending = list(self.pending_jobs)
self.pending_jobs.clear()
with self.finalize_condition:
self.finalize_condition.notify_all()
return pending
def _run_loop(self) -> None:
while True:
wait_for_batch = len(self.running_requests) == 0
wait_for_batch = self.active_batch is None
pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch)
if pending_jobs:
@ -494,37 +754,54 @@ class SchedulerDebugWorker:
self._sync_device()
prefill_start = time.perf_counter()
self._mark_prefill_started(pending_jobs, prefill_start)
admitted_running, admitted_finished = run_prefill_step(
admitted_active_batch, admitted_finished = run_prefill_active_batch(
self.tts.t2s_model.model,
[job.state for job in pending_jobs],
max_steps=self.max_steps,
)
self._sync_device()
self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start)
self._finalize_finished(admitted_finished)
self.running_requests.extend(admitted_running)
self._enqueue_finalize_finished(admitted_finished)
merge_start = time.perf_counter()
self.active_batch = merge_active_batches(
self.tts.t2s_model.model,
self.active_batch,
admitted_active_batch,
)
self._add_merge_time(
[] if self.active_batch is None else list(self.active_batch.request_ids),
time.perf_counter() - merge_start,
)
with self.finalize_condition:
self.finalize_condition.notify_all()
except Exception as exc:
self._finalize_error([job.request_id for job in pending_jobs], str(exc))
if self.running_requests:
if self.active_batch is not None:
try:
active_request_ids = [item.state.request_id for item in self.running_requests]
active_request_ids = [state.request_id for state in self.active_batch.states]
self._sync_device()
decode_start = time.perf_counter()
self.running_requests, step_finished = run_decode_step_for_running(
self.active_batch, step_finished = decode_one_step(
self.tts.t2s_model.model,
self.running_requests,
self.active_batch,
max_steps=self.max_steps,
)
self._sync_device()
self._add_decode_time(active_request_ids, time.perf_counter() - decode_start)
self._finalize_finished(step_finished)
self._enqueue_finalize_finished(step_finished)
with self.finalize_condition:
self.finalize_condition.notify_all()
except Exception as exc:
self._finalize_error(active_request_ids, str(exc))
self.running_requests = []
self.active_batch = None
with self.finalize_condition:
self.finalize_condition.notify_all()
continue
if not pending_jobs:
with self.finalize_condition:
self.finalize_condition.notify_all()
time.sleep(self.micro_batch_wait_s)
@ -788,10 +1065,6 @@ def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
]
def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
    """Prepare one scheduler state per request spec, keeping the input order.

    Each spec is passed, one after another, to
    ``scheduler_debug_worker.prepare_state``; the results are collected into
    a new list that is returned to the caller.
    """
    prepared_states = []
    for request_spec in specs:
        prepared_states.append(scheduler_debug_worker.prepare_state(request_spec))
    return prepared_states
def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec:
payload = request.dict()
request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}"
@ -845,7 +1118,7 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
try:
set_scheduler_seed(request.seed)
specs = build_scheduler_request_specs(request.requests)
states = await asyncio.to_thread(prepare_scheduler_states_batch, specs)
states = await scheduler_debug_worker.prepare_states_batch_async(specs)
finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps))
return JSONResponse(
status_code=200,
@ -867,20 +1140,51 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
try:
request_start = time.perf_counter()
prepare_start = request_start
spec = build_scheduler_submit_spec(request)
prepare_start = time.perf_counter()
state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec)
prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0
prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms))
spec_ready_at = time.perf_counter()
prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async(
spec,
spec_ready_at,
)
prepare_end = time.perf_counter()
prepare_wall_ms = (prepare_end - prepare_start) * 1000.0
prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms))
prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms))
prepare_executor_queue_ms = float(
state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0))
)
prepare_executor_run_ms = float(
state.prepare_profile.get(
"executor_run_wall_ms",
max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0),
)
)
prepare_other_ms = max(
0.0,
prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms,
)
loop = asyncio.get_running_loop()
done_future = loop.create_future()
job = scheduler_debug_worker.submit(
state,
speed_factor=float(request.speed_factor),
sample_steps=int(request.sample_steps),
media_type=request.media_type,
prepare_ms=prepare_ms,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=loop,
done_future=done_future,
)
timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec))
api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0)
timeout_ok = False
try:
await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec))
timeout_ok = True
except asyncio.TimeoutError:
timeout_ok = False
wait_return_at = time.perf_counter()
if not timeout_ok:
return JSONResponse(
status_code=202,
@ -888,8 +1192,10 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
"message": "queued",
"request_id": job.request_id,
"timings": {
"prepare_ms": prepare_ms,
"prepare_ms": prepare_wall_ms,
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"api_after_prepare_ms": api_after_prepare_ms,
"request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0),
},
"worker_state": scheduler_debug_worker.get_state(),
@ -911,9 +1217,13 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
)
pack_start = time.perf_counter()
audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
pack_ms = (time.perf_counter() - pack_start) * 1000.0
pack_end = time.perf_counter()
pack_ms = (pack_end - pack_start) * 1000.0
job.pack_ms = pack_ms
request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
api_wait_result_ms = 0.0
if job.result_ready_time is not None:
api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0
headers = {
"X-Request-Id": job.request_id,
"X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0",
@ -921,16 +1231,32 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
"X-Queue-Wait-Ms": (
f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000"
),
"X-Prepare-Ms": f"{prepare_ms:.3f}",
"X-Prepare-Ms": f"{prepare_wall_ms:.3f}",
"X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}",
"X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}",
"X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}",
"X-Prepare-Admission-Wait-Ms": (
f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}"
if job.result is not None
else "0.000"
),
"X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}",
"X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}",
"X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}",
"X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}",
"X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}",
"X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000",
"X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000",
"X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000",
"X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000",
"X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000",
"X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000",
"X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000",
"X-Pack-Ms": f"{pack_ms:.3f}",
"X-Worker-Total-Ms": (
f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000"
),
"X-Request-Total-Ms": f"{request_total_ms:.3f}",
"X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}",
"X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0",
}
if job.result is not None:
@ -939,16 +1265,48 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
{
"X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}",
"X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}",
"X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}",
"X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}",
"X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(
int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))
),
"X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(
int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))
),
"X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(
int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))
),
"X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(
int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))
),
"X-Prepare-Prompt-Bert-High-Pressure-Peak": str(
int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))
),
"X-Prepare-Target-Bert-High-Pressure-Peak": str(
int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))
),
"X-Prepare-Prompt-Bert-Batch-Size-Peak": str(
int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0))
),
"X-Prepare-Target-Bert-Batch-Size-Peak": str(
int(prepare_profile.get("text_bert_batch_size_peak", 0.0))
),
"X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}",
"X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}",
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
"X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}",
@ -964,13 +1322,22 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
"X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}",
"X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}",
"X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}",
"X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}",
"X-Prepare-Inflight-On-Enter": str(
int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))
),
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
}
)
response_ready_at = time.perf_counter()
response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
request_other_ms = max(
0.0,
request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms,
)
headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}"
headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}"
headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}"
return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)
except Exception as e:
return JSONResponse(