mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-21 02:22:42 +08:00
Refactor TTS and scheduler components to enhance text processing and batching capabilities. Introduce PrepareCoordinator for managing text feature preparation asynchronously, and update SchedulerDebugWorker to support new finalize task management. Implement batch processing in PrepareBertBatchWorker with improved admission control and profiling metrics. Add text CPU preprocessing utilities for better text segmentation and normalization.
This commit is contained in:
parent
a45e171ff5
commit
827d6ea47c
@ -1,26 +1,27 @@
|
||||
import gc
|
||||
import concurrent.futures
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
import os
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from runtime_preload import preload_text_runtime_deps
|
||||
|
||||
preload_text_runtime_deps()
|
||||
|
||||
import ffmpeg
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import yaml
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from BigVGAN.bigvgan import BigVGAN
|
||||
@ -30,6 +31,7 @@ from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.audio_sr import AP_BWE
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
@ -449,20 +451,21 @@ class TTS:
|
||||
"overlapped_len": None,
|
||||
}
|
||||
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
|
||||
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2")))
|
||||
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
|
||||
self.prepare_bert_batch_worker = None
|
||||
self.prepare_ref_semantic_batch_worker = None
|
||||
default_text_cpu_workers = 16
|
||||
self.prepare_text_cpu_workers = max(
|
||||
0,
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))),
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")),
|
||||
)
|
||||
self.prepare_text_cpu_executor = None
|
||||
if self.prepare_text_cpu_workers > 0:
|
||||
self.prepare_text_cpu_executor = ThreadPoolExecutor(
|
||||
self.prepare_text_cpu_executor = (
|
||||
concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=self.prepare_text_cpu_workers,
|
||||
thread_name_prefix="prepare-text-cpu",
|
||||
)
|
||||
if self.prepare_text_cpu_workers > 0
|
||||
else None
|
||||
)
|
||||
|
||||
self._init_models()
|
||||
|
||||
@ -475,6 +478,20 @@ class TTS:
|
||||
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
|
||||
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
|
||||
max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
|
||||
max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")),
|
||||
admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")),
|
||||
high_pressure_pending_threshold=int(
|
||||
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0")
|
||||
),
|
||||
high_pressure_batch_window_ms=int(
|
||||
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1")
|
||||
),
|
||||
high_pressure_max_batch_items=int(
|
||||
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32")
|
||||
),
|
||||
high_pressure_max_batch_tokens=int(
|
||||
os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192")
|
||||
),
|
||||
)
|
||||
if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
|
||||
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
|
||||
@ -830,13 +847,23 @@ class TTS:
|
||||
return codes[0, 0].to(self.configs.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
cpu_prepare_start = time.perf_counter()
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
return self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
||||
cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
|
||||
forward_start = time.perf_counter()
|
||||
prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
return prompt_semantic, cpu_prepare_ms, forward_ms
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
|
||||
return prompt_semantic
|
||||
|
||||
def extract_prompt_semantic(self, ref_wav_path: str):
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path)
|
||||
@ -887,7 +914,9 @@ class TTS:
|
||||
if self.prepare_ref_semantic_batch_worker is None:
|
||||
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
|
||||
prompt_semantic_start = time.perf_counter()
|
||||
prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
|
||||
prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = (
|
||||
self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
|
||||
)
|
||||
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
|
||||
ref_spec_start = time.perf_counter()
|
||||
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
|
||||
@ -897,8 +926,8 @@ class TTS:
|
||||
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_cpu_prepare_ms": 0.0,
|
||||
"prompt_semantic_forward_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
|
||||
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
@ -1012,6 +1041,32 @@ class TTS:
|
||||
text, language, self.configs.version, profile=profile
|
||||
)
|
||||
|
||||
def prepare_text_segments(self, text: str, language: str):
|
||||
return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version)
|
||||
|
||||
def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None):
|
||||
return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
|
||||
|
||||
async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None):
|
||||
return await self.text_preprocessor.build_phones_and_bert_from_segments_async(
|
||||
prepared_segments,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
async def build_text_feature_pair_from_segments_async(
|
||||
self,
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
prompt_profile: dict | None = None,
|
||||
target_profile: dict | None = None,
|
||||
):
|
||||
return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async(
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
prompt_profile=prompt_profile,
|
||||
target_profile=target_profile,
|
||||
)
|
||||
|
||||
def _set_ref_audio_path(self, ref_audio_path):
|
||||
self.prompt_cache["ref_audio_path"] = ref_audio_path
|
||||
|
||||
@ -2011,6 +2066,79 @@ class TTS:
|
||||
sample_steps=sample_steps,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def synthesize_audio_requests_local_batched(
|
||||
self,
|
||||
semantic_tokens_list: List[torch.Tensor],
|
||||
phones_list: List[torch.Tensor],
|
||||
refer_specs: List[tuple],
|
||||
speeds: List[float] | None = None,
|
||||
sample_steps_list: List[int] | None = None,
|
||||
) -> List[torch.Tensor]:
|
||||
batch_size = len(semantic_tokens_list)
|
||||
if batch_size == 0:
|
||||
return []
|
||||
if len(phones_list) != batch_size or len(refer_specs) != batch_size:
|
||||
raise ValueError("batched request-local synthesis 输入长度不一致")
|
||||
if speeds is None:
|
||||
speeds = [1.0] * batch_size
|
||||
if sample_steps_list is None:
|
||||
sample_steps_list = [32] * batch_size
|
||||
if len(speeds) != batch_size or len(sample_steps_list) != batch_size:
|
||||
raise ValueError("batched request-local synthesis 参数长度不一致")
|
||||
first_speed = float(speeds[0])
|
||||
first_sample_steps = int(sample_steps_list[0])
|
||||
if any(abs(float(item) - first_speed) > 1e-6 for item in speeds):
|
||||
raise ValueError("batched request-local synthesis 目前要求 speed 一致")
|
||||
if any(int(item) != first_sample_steps for item in sample_steps_list):
|
||||
raise ValueError("batched request-local synthesis 目前要求 sample_steps 一致")
|
||||
if self.configs.use_vocoder:
|
||||
raise NotImplementedError("request-local batched VITS synthesis 暂不支持 vocoder 模型")
|
||||
|
||||
device = self.configs.device
|
||||
max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list)
|
||||
max_phone_len = max(int(item.shape[-1]) for item in phones_list)
|
||||
semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device)
|
||||
phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device)
|
||||
semantic_lengths = []
|
||||
phone_lengths = []
|
||||
refer_audio_specs: List[torch.Tensor] = []
|
||||
sv_emb_batch = None
|
||||
sv_emb_list: List[torch.Tensor] = []
|
||||
|
||||
for batch_index, semantic_tokens in enumerate(semantic_tokens_list):
|
||||
semantic_len = int(semantic_tokens.shape[-1])
|
||||
phone_len = int(phones_list[batch_index].shape[-1])
|
||||
semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long)
|
||||
phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long)
|
||||
semantic_lengths.append(semantic_len)
|
||||
phone_lengths.append(phone_len)
|
||||
|
||||
refer_audio_spec, audio_tensor = refer_specs[batch_index]
|
||||
refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device))
|
||||
if self.is_v2pro:
|
||||
if audio_tensor is None:
|
||||
raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频"))
|
||||
sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device))
|
||||
|
||||
if self.is_v2pro:
|
||||
sv_emb_batch = torch.cat(sv_emb_list, dim=0)
|
||||
|
||||
audio_batch, audio_lengths = self.vits_model.decode_batched_request_local(
|
||||
codes=semantic_batch,
|
||||
code_lengths=torch.LongTensor(semantic_lengths).to(device),
|
||||
text=phone_batch,
|
||||
text_lengths=torch.LongTensor(phone_lengths).to(device),
|
||||
refer_list=refer_audio_specs,
|
||||
speed=first_speed,
|
||||
sv_emb=sv_emb_batch,
|
||||
)
|
||||
audios: List[torch.Tensor] = []
|
||||
for batch_index in range(batch_size):
|
||||
audio_len = int(audio_lengths[batch_index].item())
|
||||
audios.append(audio_batch[batch_index, 0, :audio_len].detach())
|
||||
return audios
|
||||
|
||||
def using_vocoder_synthesis_batched_infer(
|
||||
self,
|
||||
idx_list: List[int],
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -13,12 +15,13 @@ import re
|
||||
import torch
|
||||
from text.LangSegmenter import LangSegmenter
|
||||
from text import chinese
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from text.cleaner import clean_text
|
||||
from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
|
||||
from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
@ -92,6 +95,14 @@ class StageLimiter:
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreparedTextSegment:
|
||||
language: str
|
||||
phones: List[int]
|
||||
word2ph: Optional[List[int]]
|
||||
norm_text: str
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(
|
||||
self,
|
||||
@ -149,7 +160,7 @@ class TextPreprocessor:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if len(text.strip()) == 0:
|
||||
continue
|
||||
if not re.sub("\W+", "", text):
|
||||
if not re.sub(r"\W+", "", text):
|
||||
# 检测一下,如果是纯符号,就跳过。
|
||||
continue
|
||||
if text[-1] not in splits:
|
||||
@ -168,7 +179,8 @@ class TextPreprocessor:
|
||||
def segment_and_extract_feature_for_text(
|
||||
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version, profile=profile)
|
||||
prepared_segments = self.preprocess_text_segments(text, language, version)
|
||||
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
|
||||
|
||||
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
|
||||
textlist = []
|
||||
@ -223,24 +235,49 @@ class TextPreprocessor:
|
||||
def get_phones_and_bert(
|
||||
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
|
||||
):
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
textlist, langlist = self._split_text_by_language(text, language)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for segment_text, segment_lang in zip(textlist, langlist):
|
||||
phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
prepared_segments = self.preprocess_text_segments(text, language, version, final=final)
|
||||
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
|
||||
|
||||
def preprocess_text_segments(
|
||||
self,
|
||||
text: str,
|
||||
language: str,
|
||||
version: str,
|
||||
final: bool = False,
|
||||
) -> List[PreparedTextSegment]:
|
||||
payloads = preprocess_text_segments_payload(text, language, version, final=final)
|
||||
return [
|
||||
PreparedTextSegment(
|
||||
language=str(payload["language"]),
|
||||
phones=list(payload["phones"]),
|
||||
word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
|
||||
norm_text=str(payload["norm_text"]),
|
||||
)
|
||||
for payload in payloads
|
||||
]
|
||||
|
||||
def build_phones_and_bert_from_segments(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None = None,
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
phones_list: List[List[int]] = []
|
||||
bert_list: List[torch.Tensor] = []
|
||||
norm_text_list: List[str] = []
|
||||
for segment in prepared_segments:
|
||||
bert = self.get_bert_inf(
|
||||
segment.phones,
|
||||
segment.word2ph,
|
||||
segment.norm_text,
|
||||
segment.language,
|
||||
profile=profile,
|
||||
)
|
||||
phones_list.append(segment.phones)
|
||||
norm_text_list.append(segment.norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile)
|
||||
|
||||
return phones, bert, norm_text
|
||||
|
||||
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
@ -253,21 +290,41 @@ class TextPreprocessor:
|
||||
return
|
||||
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
|
||||
|
||||
def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None:
|
||||
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0))
|
||||
self._accumulate_profile(
|
||||
profile,
|
||||
"bert_batch_collect_wait_ms",
|
||||
worker_profile.get("bert_batch_collect_wait_ms", 0.0),
|
||||
)
|
||||
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
|
||||
self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"bert_pending_depth_on_enqueue_peak",
|
||||
worker_profile.get("bert_pending_depth_on_enqueue", 0.0),
|
||||
)
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"bert_pending_depth_on_collect_peak",
|
||||
worker_profile.get("bert_pending_depth_on_collect", 0.0),
|
||||
)
|
||||
self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0))
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
|
||||
profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0))
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
|
||||
if self.bert_batch_worker is not None:
|
||||
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
|
||||
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
|
||||
self._update_profile_peak(
|
||||
profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)
|
||||
)
|
||||
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
|
||||
self._merge_bert_worker_profile(profile, worker_profile)
|
||||
return feature
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
|
||||
@ -310,9 +367,18 @@ class TextPreprocessor:
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None):
|
||||
def get_bert_inf(
|
||||
self,
|
||||
phones: list,
|
||||
word2ph: Optional[list],
|
||||
norm_text: str,
|
||||
language: str,
|
||||
profile: Dict | None = None,
|
||||
):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
if word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
|
||||
else:
|
||||
feature = torch.zeros(
|
||||
@ -322,6 +388,112 @@ class TextPreprocessor:
|
||||
|
||||
return feature
|
||||
|
||||
async def build_phones_and_bert_from_segments_async(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None = None,
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
|
||||
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
|
||||
for segment_index, segment in enumerate(prepared_segments):
|
||||
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
|
||||
continue
|
||||
if segment.word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
pending_items.append(
|
||||
(
|
||||
segment_jobs["bert_list"],
|
||||
segment_index,
|
||||
profile,
|
||||
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
|
||||
)
|
||||
)
|
||||
|
||||
if pending_items:
|
||||
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
|
||||
for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
|
||||
self._merge_bert_worker_profile(item_profile, worker_profile)
|
||||
bert_list[bert_index] = feature.to(self.device)
|
||||
|
||||
return self._finalize_async_segment_jobs(segment_jobs)
|
||||
|
||||
def _build_async_segment_jobs(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None,
|
||||
) -> Dict[str, List]:
|
||||
phones_list: List[List[int]] = []
|
||||
bert_list: List[torch.Tensor | None] = []
|
||||
norm_text_list: List[str] = []
|
||||
|
||||
for segment in prepared_segments:
|
||||
phones_list.append(segment.phones)
|
||||
norm_text_list.append(segment.norm_text)
|
||||
segment_language = segment.language.replace("all_", "")
|
||||
if segment_language == "zh" and self.bert_batch_worker is not None:
|
||||
if segment.word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
bert_list.append(None)
|
||||
continue
|
||||
bert_list.append(
|
||||
self.get_bert_inf(
|
||||
segment.phones,
|
||||
segment.word2ph,
|
||||
segment.norm_text,
|
||||
segment.language,
|
||||
profile=profile,
|
||||
)
|
||||
)
|
||||
return {
|
||||
"phones_list": phones_list,
|
||||
"bert_list": bert_list,
|
||||
"norm_text_list": norm_text_list,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]:
|
||||
bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1)
|
||||
phones = sum(segment_jobs["phones_list"], [])
|
||||
norm_text = "".join(segment_jobs["norm_text_list"])
|
||||
return phones, bert, norm_text
|
||||
|
||||
async def build_phones_and_bert_pair_from_segments_async(
|
||||
self,
|
||||
prompt_segments: List[PreparedTextSegment],
|
||||
target_segments: List[PreparedTextSegment],
|
||||
prompt_profile: Dict | None = None,
|
||||
target_profile: Dict | None = None,
|
||||
) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
|
||||
prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
|
||||
target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
|
||||
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
|
||||
|
||||
for segment_jobs, prepared_segments, profile in (
|
||||
(prompt_jobs, prompt_segments, prompt_profile),
|
||||
(target_jobs, target_segments, target_profile),
|
||||
):
|
||||
for segment_index, segment in enumerate(prepared_segments):
|
||||
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
|
||||
continue
|
||||
if segment.word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
pending_items.append(
|
||||
(
|
||||
segment_jobs["bert_list"],
|
||||
segment_index,
|
||||
profile,
|
||||
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
|
||||
)
|
||||
)
|
||||
|
||||
if pending_items:
|
||||
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
|
||||
for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
|
||||
self._merge_bert_worker_profile(profile, worker_profile)
|
||||
bert_list[bert_index] = feature.to(self.device)
|
||||
|
||||
return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs)
|
||||
|
||||
def filter_text(self, texts):
|
||||
_text = []
|
||||
if all(text in [None, " ", "\n", ""] for text in texts):
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@ -14,7 +15,12 @@ class BertFeatureTask:
|
||||
word2ph: List[int]
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
enqueued_at: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
pending_depth_on_enqueue: int = 0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
result_feature: torch.Tensor | None = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
@ -30,14 +36,37 @@ class PrepareBertBatchWorker:
|
||||
batch_window_ms: int = 5,
|
||||
max_batch_items: int = 16,
|
||||
max_batch_tokens: int = 4096,
|
||||
max_pending_tasks: int = 0,
|
||||
admission_poll_ms: int = 1,
|
||||
high_pressure_pending_threshold: int = 0,
|
||||
high_pressure_batch_window_ms: int | None = None,
|
||||
high_pressure_max_batch_items: int | None = None,
|
||||
high_pressure_max_batch_tokens: int | None = None,
|
||||
):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.stage_limiter = stage_limiter
|
||||
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
|
||||
self.batch_window_ms = max(0, int(batch_window_ms))
|
||||
self.batch_window_s = float(self.batch_window_ms) / 1000.0
|
||||
self.max_batch_items = max(1, int(max_batch_items))
|
||||
self.max_batch_tokens = max(16, int(max_batch_tokens))
|
||||
self.max_pending_tasks = max(0, int(max_pending_tasks))
|
||||
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
|
||||
|
||||
self.high_pressure_pending_threshold = max(
|
||||
0,
|
||||
int(high_pressure_pending_threshold)
|
||||
if int(high_pressure_pending_threshold) > 0
|
||||
else max(self.max_batch_items * 2, 32),
|
||||
)
|
||||
hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms)
|
||||
hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items)
|
||||
hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens)
|
||||
self.high_pressure_batch_window_ms = max(0, hp_window_ms)
|
||||
self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0
|
||||
self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items)
|
||||
self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens)
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[BertFeatureTask] = deque()
|
||||
@ -47,26 +76,70 @@ class PrepareBertBatchWorker:
|
||||
self.total_batches = 0
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_peak = 0
|
||||
self.active_batch_tokens = 0
|
||||
self.active_batch_tokens_peak = 0
|
||||
self.high_pressure_batches = 0
|
||||
self.admission_wait_total_ms = 0.0
|
||||
self.admission_wait_peak_ms = 0.0
|
||||
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
|
||||
return max(1, len(task.norm_text) + 2)
|
||||
|
||||
def _can_enqueue_locked(self) -> bool:
|
||||
if self.max_pending_tasks <= 0:
|
||||
return True
|
||||
return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks
|
||||
|
||||
def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None:
|
||||
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
|
||||
task.enqueued_at = time.perf_counter()
|
||||
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
self.admission_wait_total_ms += task.admission_wait_ms
|
||||
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
|
||||
def _enqueue_task(self, task: BertFeatureTask) -> None:
|
||||
admission_started = time.perf_counter()
|
||||
with self.condition:
|
||||
while not self._can_enqueue_locked():
|
||||
self.condition.wait(timeout=self.admission_poll_s)
|
||||
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
|
||||
|
||||
async def _enqueue_task_async(self, task: BertFeatureTask) -> None:
|
||||
admission_started = time.perf_counter()
|
||||
while True:
|
||||
with self.condition:
|
||||
if self._can_enqueue_locked():
|
||||
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
|
||||
return
|
||||
await asyncio.sleep(self.admission_poll_s)
|
||||
|
||||
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
|
||||
with self.condition:
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
self._enqueue_task(task)
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.result_feature is not None
|
||||
return task.result_feature, dict(task.profile)
|
||||
|
||||
async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = BertFeatureTask(
|
||||
norm_text=str(norm_text),
|
||||
word2ph=list(word2ph),
|
||||
done_loop=loop,
|
||||
done_future=loop.create_future(),
|
||||
)
|
||||
await self._enqueue_task_async(task)
|
||||
return await task.done_future
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return {
|
||||
@ -77,21 +150,57 @@ class PrepareBertBatchWorker:
|
||||
"total_batches": self.total_batches,
|
||||
"active_batch_size": self.active_batch_size,
|
||||
"active_batch_peak": self.active_batch_peak,
|
||||
"active_batch_tokens": self.active_batch_tokens,
|
||||
"active_batch_tokens_peak": self.active_batch_tokens_peak,
|
||||
"batch_window_ms": int(self.batch_window_s * 1000.0),
|
||||
"max_batch_items": self.max_batch_items,
|
||||
"max_batch_tokens": self.max_batch_tokens,
|
||||
"max_pending_tasks": self.max_pending_tasks,
|
||||
"high_pressure_pending_threshold": self.high_pressure_pending_threshold,
|
||||
"high_pressure_batch_window_ms": self.high_pressure_batch_window_ms,
|
||||
"high_pressure_max_batch_items": self.high_pressure_max_batch_items,
|
||||
"high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens,
|
||||
"high_pressure_batches": self.high_pressure_batches,
|
||||
"admission_wait_total_ms": self.admission_wait_total_ms,
|
||||
"admission_wait_peak_ms": self.admission_wait_peak_ms,
|
||||
}
|
||||
|
||||
def _collect_batch(self) -> List[BertFeatureTask]:
|
||||
def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]:
|
||||
pending_depth = len(self.pending_tasks)
|
||||
use_high_pressure = (
|
||||
self.high_pressure_pending_threshold > 0
|
||||
and pending_depth >= self.high_pressure_pending_threshold
|
||||
)
|
||||
if use_high_pressure:
|
||||
return (
|
||||
self.high_pressure_batch_window_s,
|
||||
self.high_pressure_max_batch_items,
|
||||
self.high_pressure_max_batch_tokens,
|
||||
True,
|
||||
pending_depth,
|
||||
)
|
||||
return (
|
||||
self.batch_window_s,
|
||||
self.max_batch_items,
|
||||
self.max_batch_tokens,
|
||||
False,
|
||||
pending_depth,
|
||||
)
|
||||
|
||||
def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
|
||||
collect_started = time.perf_counter()
|
||||
batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = (
|
||||
self._select_batch_policy_locked()
|
||||
)
|
||||
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
|
||||
batch_tokens = self._estimate_task_tokens(batch[0])
|
||||
deadline = time.perf_counter() + self.batch_window_s
|
||||
deadline = time.perf_counter() + batch_window_s
|
||||
|
||||
while len(batch) < self.max_batch_items:
|
||||
while len(batch) < max_batch_items:
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
@ -100,26 +209,39 @@ class PrepareBertBatchWorker:
|
||||
continue
|
||||
next_task = self.pending_tasks[0]
|
||||
next_tokens = self._estimate_task_tokens(next_task)
|
||||
if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens:
|
||||
if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens:
|
||||
break
|
||||
batch.append(self.pending_tasks.popleft())
|
||||
batch_tokens += next_tokens
|
||||
|
||||
self.active_batch_size = len(batch)
|
||||
self.active_batch_tokens = batch_tokens
|
||||
if self.active_batch_size > self.active_batch_peak:
|
||||
self.active_batch_peak = self.active_batch_size
|
||||
return batch
|
||||
if self.active_batch_tokens > self.active_batch_tokens_peak:
|
||||
self.active_batch_tokens_peak = self.active_batch_tokens
|
||||
if use_high_pressure:
|
||||
self.high_pressure_batches += 1
|
||||
return batch, {
|
||||
"collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0,
|
||||
"batch_tokens": float(batch_tokens),
|
||||
"pending_depth_on_collect": float(pending_depth_on_collect),
|
||||
"high_pressure_mode": 1.0 if use_high_pressure else 0.0,
|
||||
"batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms),
|
||||
}
|
||||
|
||||
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
|
||||
with self.condition:
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_tokens = 0
|
||||
self.total_batches += 1
|
||||
self.total_finished += len(batch)
|
||||
self.condition.notify_all()
|
||||
|
||||
def _run_batch(self, batch: List[BertFeatureTask]) -> None:
|
||||
def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None:
|
||||
batch_started = time.perf_counter()
|
||||
texts = [task.norm_text for task in batch]
|
||||
batch_tokens = sum(self._estimate_task_tokens(task) for task in batch)
|
||||
batch_tokens = int(batch_meta["batch_tokens"])
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
|
||||
if self.stage_limiter is None:
|
||||
@ -167,6 +289,9 @@ class PrepareBertBatchWorker:
|
||||
task.result_feature = torch.cat(phone_level_feature, dim=0).T
|
||||
task.profile = {
|
||||
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
|
||||
"bert_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0),
|
||||
"bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]),
|
||||
"bert_forward_ms": float(forward_ms),
|
||||
"bert_tokenize_ms": float(tokenize_ms),
|
||||
"bert_scatter_ms": 0.0,
|
||||
@ -175,6 +300,10 @@ class PrepareBertBatchWorker:
|
||||
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
"bert_batch_size": float(len(batch)),
|
||||
"bert_batch_tokens": float(batch_tokens),
|
||||
"bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
|
||||
"bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]),
|
||||
"bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]),
|
||||
"bert_batch_window_ms": float(batch_meta["batch_window_ms"]),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
@ -183,15 +312,35 @@ class PrepareBertBatchWorker:
|
||||
if task.result_feature is not None:
|
||||
task.profile["bert_scatter_ms"] = float(scatter_ms)
|
||||
task.done_event.set()
|
||||
self._notify_done_future(task)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(task: BertFeatureTask) -> None:
|
||||
if task.done_future is None or task.done_future.done():
|
||||
return
|
||||
if task.error is not None:
|
||||
task.done_future.set_exception(task.error)
|
||||
return
|
||||
assert task.result_feature is not None
|
||||
task.done_future.set_result((task.result_feature, dict(task.profile)))
|
||||
|
||||
def _notify_done_future(self, task: BertFeatureTask) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
batch = self._collect_batch()
|
||||
batch, batch_meta = self._collect_batch()
|
||||
try:
|
||||
self._run_batch(batch)
|
||||
self._run_batch(batch, batch_meta)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
task.done_event.set()
|
||||
self._notify_done_future(task)
|
||||
finally:
|
||||
self._finalize_batch(batch)
|
||||
|
||||
294
GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
Normal file
294
GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
Normal file
@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
|
||||
PreparedTextFeatures,
|
||||
SchedulerRequestSpec,
|
||||
T2SRequestState,
|
||||
build_request_state_from_parts,
|
||||
normalize_sentence,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfiledResult:
|
||||
result: Any
|
||||
submit_at: float
|
||||
started_at: float
|
||||
finished_at: float
|
||||
|
||||
@property
|
||||
def queue_ms(self) -> float:
|
||||
return max(0.0, (self.started_at - self.submit_at) * 1000.0)
|
||||
|
||||
@property
|
||||
def run_ms(self) -> float:
|
||||
return max(0.0, (self.finished_at - self.started_at) * 1000.0)
|
||||
|
||||
|
||||
class PrepareCoordinator:
|
||||
def __init__(self, tts: Any):
|
||||
self.tts = tts
|
||||
self.lock = threading.Lock()
|
||||
self.inflight = 0
|
||||
self.peak_inflight = 0
|
||||
self.use_async_text_feature_path = bool(
|
||||
getattr(tts, "prepare_bert_batch_worker", None) is not None
|
||||
and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0"
|
||||
)
|
||||
self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0")))
|
||||
self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None
|
||||
self.text_feature_workers = 0
|
||||
self.text_feature_executor = None
|
||||
if not self.use_async_text_feature_path:
|
||||
text_feature_default_workers = max(1, int(getattr(tts, "prepare_text_cpu_workers", 16) or 16))
|
||||
self.text_feature_workers = max(
|
||||
1,
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS", str(text_feature_default_workers))),
|
||||
)
|
||||
self.text_feature_executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=self.text_feature_workers,
|
||||
thread_name_prefix="prepare-text-feature",
|
||||
)
|
||||
ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
|
||||
self.ref_audio_workers = max(
|
||||
1,
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_REF_ASYNC_WORKERS", str(ref_audio_default_workers))),
|
||||
)
|
||||
self.ref_audio_executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=self.ref_audio_workers,
|
||||
thread_name_prefix="prepare-ref-audio",
|
||||
)
|
||||
|
||||
def _mark_enter(self) -> Tuple[int, int]:
|
||||
with self.lock:
|
||||
self.inflight += 1
|
||||
current_inflight = self.inflight
|
||||
if current_inflight > self.peak_inflight:
|
||||
self.peak_inflight = current_inflight
|
||||
return current_inflight, self.peak_inflight
|
||||
|
||||
def _mark_leave(self) -> None:
|
||||
with self.lock:
|
||||
self.inflight = max(0, self.inflight - 1)
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.lock:
|
||||
return {
|
||||
"inflight": int(self.inflight),
|
||||
"peak_inflight": int(self.peak_inflight),
|
||||
"max_inflight": int(self.max_inflight),
|
||||
"text_feature_workers": int(self.text_feature_workers),
|
||||
"ref_audio_workers": int(self.ref_audio_workers),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult:
|
||||
started_at = time.perf_counter()
|
||||
result = fn(*args)
|
||||
finished_at = time.perf_counter()
|
||||
return ProfiledResult(
|
||||
result=result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=float(started_at),
|
||||
finished_at=float(finished_at),
|
||||
)
|
||||
|
||||
def _prepare_text_cpu(self, text: str, language: str):
|
||||
return self.tts.prepare_text_segments(text, language)
|
||||
|
||||
def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures:
|
||||
profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)}
|
||||
branch_start = time.perf_counter()
|
||||
phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile)
|
||||
total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0)
|
||||
profile["bert_total_ms"] = max(0.0, total_ms - float(cpu_run_ms))
|
||||
return PreparedTextFeatures(
|
||||
phones=phones,
|
||||
bert_features=bert_features,
|
||||
norm_text=norm_text,
|
||||
profile=profile,
|
||||
total_ms=total_ms,
|
||||
cpu_preprocess_ms=float(cpu_run_ms),
|
||||
)
|
||||
|
||||
async def _run_on_executor(self, executor, fn, *args) -> ProfiledResult:
|
||||
loop = asyncio.get_running_loop()
|
||||
submit_at = time.perf_counter()
|
||||
return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args)
|
||||
|
||||
async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult:
|
||||
executor = getattr(self.tts, "prepare_text_cpu_executor", None)
|
||||
if executor is None:
|
||||
submit_at = time.perf_counter()
|
||||
return self._run_profiled(self._prepare_text_cpu, submit_at, text, language)
|
||||
return await self._run_on_executor(executor, self._prepare_text_cpu, text, language)
|
||||
|
||||
async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult:
|
||||
return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms)
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float:
|
||||
return float(
|
||||
profile.get("bert_wait_ms", 0.0)
|
||||
+ profile.get("bert_tokenize_ms", 0.0)
|
||||
+ profile.get("bert_forward_ms", 0.0)
|
||||
+ profile.get("bert_scatter_ms", 0.0)
|
||||
)
|
||||
|
||||
async def _run_text_feature_pair_stage(
|
||||
self,
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
prompt_cpu_run_ms: float,
|
||||
target_cpu_run_ms: float,
|
||||
) -> tuple[ProfiledResult, ProfiledResult]:
|
||||
if self.text_feature_executor is not None:
|
||||
prompt_feature_task = asyncio.create_task(
|
||||
self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms)
|
||||
)
|
||||
target_feature_task = asyncio.create_task(
|
||||
self._run_text_feature_stage(target_segments, None, target_cpu_run_ms)
|
||||
)
|
||||
return await asyncio.gather(prompt_feature_task, target_feature_task)
|
||||
|
||||
prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)}
|
||||
target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)}
|
||||
submit_at = time.perf_counter()
|
||||
started_at = float(submit_at)
|
||||
prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async(
|
||||
prompt_segments,
|
||||
target_segments,
|
||||
prompt_profile=prompt_profile,
|
||||
target_profile=target_profile,
|
||||
)
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
prompt_result = PreparedTextFeatures(
|
||||
phones=prompt_result_raw[0],
|
||||
bert_features=prompt_result_raw[1],
|
||||
norm_text=prompt_result_raw[2],
|
||||
profile=prompt_profile,
|
||||
total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)),
|
||||
cpu_preprocess_ms=float(prompt_cpu_run_ms),
|
||||
)
|
||||
target_result = PreparedTextFeatures(
|
||||
phones=target_result_raw[0],
|
||||
bert_features=target_result_raw[1],
|
||||
norm_text=target_result_raw[2],
|
||||
profile=target_profile,
|
||||
total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)),
|
||||
cpu_preprocess_ms=float(target_cpu_run_ms),
|
||||
)
|
||||
prompt_profiled = ProfiledResult(
|
||||
result=prompt_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0),
|
||||
)
|
||||
target_profiled = ProfiledResult(
|
||||
result=target_result,
|
||||
submit_at=float(submit_at),
|
||||
started_at=started_at,
|
||||
finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0),
|
||||
)
|
||||
if finished_at > prompt_profiled.finished_at:
|
||||
prompt_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(prompt_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
target_result.profile["bert_total_ms"] = max(
|
||||
self._estimate_text_feature_run_ms(target_profile),
|
||||
(finished_at - submit_at) * 1000.0,
|
||||
)
|
||||
else:
|
||||
prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile)
|
||||
target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile)
|
||||
return prompt_profiled, target_profiled
|
||||
|
||||
async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult:
|
||||
return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path)
|
||||
|
||||
async def prepare_state_profiled_async(
|
||||
self,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
admission_start = time.perf_counter()
|
||||
if self._inflight_semaphore is not None:
|
||||
await self._inflight_semaphore.acquire()
|
||||
prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0)
|
||||
current_inflight, peak_inflight = self._mark_enter()
|
||||
prepare_start = time.perf_counter()
|
||||
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
|
||||
text = spec.text.strip("\n")
|
||||
try:
|
||||
text_pair_start = time.perf_counter()
|
||||
prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang))
|
||||
target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang))
|
||||
ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path)))
|
||||
prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task)
|
||||
text_feature_pair_task = asyncio.create_task(
|
||||
self._run_text_feature_pair_stage(
|
||||
prompt_cpu_profiled.result,
|
||||
target_cpu_profiled.result,
|
||||
prompt_cpu_profiled.run_ms,
|
||||
target_cpu_profiled.run_ms,
|
||||
)
|
||||
)
|
||||
(prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather(
|
||||
text_feature_pair_task,
|
||||
ref_audio_task,
|
||||
)
|
||||
text_pair_end = time.perf_counter()
|
||||
state = build_request_state_from_parts(
|
||||
tts=self.tts,
|
||||
spec=spec,
|
||||
prompt_text=prompt_text,
|
||||
text=text,
|
||||
prompt_result=prompt_feature_profiled.result,
|
||||
target_result=target_feature_profiled.result,
|
||||
ref_audio_bundle=ref_audio_profiled.result,
|
||||
prepare_start=prepare_start,
|
||||
prepare_sync_start=prepare_start,
|
||||
profile_overrides={
|
||||
"executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0),
|
||||
"prepare_admission_wait_ms": prepare_admission_wait_ms,
|
||||
"executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0),
|
||||
"text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0),
|
||||
"prompt_text_parallel_future_wait_ms": 0.0,
|
||||
"prompt_text_parallel_future_executor_queue_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_ms": 0.0,
|
||||
"prompt_text_parallel_future_finish_after_submit_ms": 0.0,
|
||||
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
|
||||
"prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms,
|
||||
"prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms,
|
||||
"prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms,
|
||||
"prompt_text_feature_run_ms": prompt_feature_profiled.run_ms,
|
||||
"text_cpu_queue_ms": target_cpu_profiled.queue_ms,
|
||||
"text_cpu_run_ms": target_cpu_profiled.run_ms,
|
||||
"text_feature_queue_ms": target_feature_profiled.queue_ms,
|
||||
"text_feature_run_ms": target_feature_profiled.run_ms,
|
||||
"ref_audio_task_queue_ms": ref_audio_profiled.queue_ms,
|
||||
"ref_audio_task_run_ms": ref_audio_profiled.run_ms,
|
||||
"worker_prepare_inflight_on_enter": float(current_inflight),
|
||||
"worker_prepare_peak_inflight": float(peak_inflight),
|
||||
},
|
||||
)
|
||||
prepare_exec_finished_at = time.perf_counter()
|
||||
state.prepare_profile["executor_run_wall_ms"] = max(
|
||||
0.0, (prepare_exec_finished_at - prepare_start) * 1000.0
|
||||
)
|
||||
return state, prepare_start, prepare_exec_finished_at
|
||||
finally:
|
||||
self._mark_leave()
|
||||
if self._inflight_semaphore is not None:
|
||||
self._inflight_semaphore.release()
|
||||
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import time
|
||||
@ -89,20 +88,31 @@ class T2SFinishedItem:
|
||||
class T2SActiveBatch:
|
||||
request_ids: List[str]
|
||||
states: List[T2SRequestState]
|
||||
x: torch.Tensor
|
||||
x_lens: torch.LongTensor
|
||||
x: Optional[torch.Tensor]
|
||||
x_lens: Optional[torch.LongTensor]
|
||||
y_sequences: List[torch.LongTensor]
|
||||
prefix_lens: torch.LongTensor
|
||||
xy_pos: torch.Tensor
|
||||
key_padding_mask: torch.Tensor
|
||||
prefill_attn_mask: torch.Tensor
|
||||
key_padding_mask: Optional[torch.Tensor]
|
||||
prefill_attn_mask: Optional[torch.Tensor]
|
||||
decode_attn_mask: Optional[torch.Tensor]
|
||||
k_cache: Optional[List[torch.Tensor]]
|
||||
v_cache: Optional[List[torch.Tensor]]
|
||||
step_idx: int
|
||||
kv_lens: Optional[torch.LongTensor]
|
||||
step_indices: torch.LongTensor
|
||||
prefill_done: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreparedTextFeatures:
|
||||
phones: List[int]
|
||||
bert_features: torch.Tensor
|
||||
norm_text: str
|
||||
profile: Dict[str, float]
|
||||
total_ms: float
|
||||
cpu_preprocess_ms: float
|
||||
|
||||
|
||||
def normalize_sentence(text: str, language: str) -> str:
|
||||
text = text.strip("\n").strip()
|
||||
if not text:
|
||||
@ -113,105 +123,125 @@ def normalize_sentence(text: str, language: str) -> str:
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_request_state(
|
||||
def prepare_text_features(
|
||||
tts: Any,
|
||||
text: str,
|
||||
language: str,
|
||||
) -> PreparedTextFeatures:
|
||||
device = tts.configs.device
|
||||
profile: Dict[str, float] = {}
|
||||
branch_start = time.perf_counter()
|
||||
_sync_device(device)
|
||||
cpu_start = time.perf_counter()
|
||||
prepared_segments = tts.prepare_text_segments(text, language)
|
||||
_sync_device(device)
|
||||
cpu_preprocess_ms = (time.perf_counter() - cpu_start) * 1000.0
|
||||
profile["cpu_preprocess_ms"] = float(cpu_preprocess_ms)
|
||||
bert_start = time.perf_counter()
|
||||
phones, bert_features, norm_text = tts.build_text_features_from_segments(prepared_segments, profile=profile)
|
||||
_sync_device(device)
|
||||
profile["bert_total_ms"] = (time.perf_counter() - bert_start) * 1000.0
|
||||
total_ms = (time.perf_counter() - branch_start) * 1000.0
|
||||
return PreparedTextFeatures(
|
||||
phones=phones,
|
||||
bert_features=bert_features,
|
||||
norm_text=norm_text,
|
||||
profile=profile,
|
||||
total_ms=float(total_ms),
|
||||
cpu_preprocess_ms=float(cpu_preprocess_ms),
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_request_state_from_parts(
|
||||
tts: Any,
|
||||
spec: SchedulerRequestSpec,
|
||||
prompt_text: str,
|
||||
text: str,
|
||||
prompt_result: PreparedTextFeatures,
|
||||
target_result: PreparedTextFeatures,
|
||||
ref_audio_bundle: Dict[str, Any],
|
||||
prepare_start: float,
|
||||
prepare_sync_start: float,
|
||||
profile_overrides: Optional[Dict[str, float]] = None,
|
||||
) -> T2SRequestState:
|
||||
device = tts.configs.device
|
||||
prepare_start = time.perf_counter()
|
||||
_sync_device(device)
|
||||
prepare_sync_start = time.perf_counter()
|
||||
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
|
||||
text = spec.text.strip("\n")
|
||||
|
||||
prompt_text_profile: Dict[str, float] = {}
|
||||
text_features_profile: Dict[str, float] = {}
|
||||
text_feature_pair_start = time.perf_counter()
|
||||
prompt_future: Future | None = None
|
||||
|
||||
def _extract_prompt_features():
|
||||
_sync_device(device)
|
||||
prompt_start = time.perf_counter()
|
||||
result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile)
|
||||
_sync_device(device)
|
||||
return result, (time.perf_counter() - prompt_start) * 1000.0
|
||||
|
||||
if getattr(tts, "prepare_text_cpu_executor", None) is not None:
|
||||
prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features)
|
||||
|
||||
_sync_device(device)
|
||||
text_features_start = time.perf_counter()
|
||||
phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile)
|
||||
_sync_device(device)
|
||||
text_features_ms = (time.perf_counter() - text_features_start) * 1000.0
|
||||
|
||||
if prompt_future is None:
|
||||
_sync_device(device)
|
||||
prompt_text_features_start = time.perf_counter()
|
||||
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(
|
||||
prompt_text, spec.prompt_lang, profile=prompt_text_profile
|
||||
)
|
||||
_sync_device(device)
|
||||
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
|
||||
prompt_text_profile["parallel_future_wait_ms"] = 0.0
|
||||
else:
|
||||
prompt_wait_start = time.perf_counter()
|
||||
(prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result()
|
||||
prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0
|
||||
|
||||
text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0
|
||||
if phones is None:
|
||||
raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
|
||||
|
||||
_sync_device(device)
|
||||
ref_audio_bundle_start = time.perf_counter()
|
||||
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
|
||||
ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0))
|
||||
bundle_profile = ref_audio_bundle.get("profile", {})
|
||||
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
|
||||
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
|
||||
raw_audio = ref_audio_bundle["raw_audio"]
|
||||
raw_sr = int(ref_audio_bundle["raw_sr"])
|
||||
_sync_device(device)
|
||||
ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0
|
||||
bundle_profile = ref_audio_bundle.get("profile", {})
|
||||
prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
|
||||
ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0))
|
||||
audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0))
|
||||
|
||||
_sync_device(device)
|
||||
tensorize_start = time.perf_counter()
|
||||
phones_tensor = torch.LongTensor(phones).to(tts.configs.device)
|
||||
prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device)
|
||||
all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device)
|
||||
all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to(
|
||||
phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device)
|
||||
prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device)
|
||||
all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device)
|
||||
all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to(
|
||||
dtype=tts.precision, device=tts.configs.device
|
||||
)
|
||||
_sync_device(device)
|
||||
tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0
|
||||
|
||||
_sync_device(device)
|
||||
prepare_profile = {
|
||||
"prompt_text_features_ms": prompt_text_features_ms,
|
||||
"text_features_ms": text_features_ms,
|
||||
"prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)),
|
||||
"prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)),
|
||||
"prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)),
|
||||
"prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)),
|
||||
"prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)),
|
||||
"prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)),
|
||||
"prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)),
|
||||
"prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)),
|
||||
"prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)),
|
||||
"prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)),
|
||||
"text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)),
|
||||
"text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)),
|
||||
"text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)),
|
||||
"text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)),
|
||||
"text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)),
|
||||
"text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)),
|
||||
"text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)),
|
||||
"text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)),
|
||||
"text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)),
|
||||
"text_feature_pair_ms": text_feature_pair_ms,
|
||||
"prompt_text_features_ms": float(prompt_result.total_ms),
|
||||
"text_features_ms": float(target_result.total_ms),
|
||||
"prompt_text_cpu_preprocess_ms": float(prompt_result.cpu_preprocess_ms),
|
||||
"text_cpu_preprocess_ms": float(target_result.cpu_preprocess_ms),
|
||||
"prompt_text_bert_wait_ms": float(prompt_result.profile.get("bert_wait_ms", 0.0)),
|
||||
"prompt_text_bert_admission_wait_ms": float(prompt_result.profile.get("bert_admission_wait_ms", 0.0)),
|
||||
"prompt_text_bert_queue_wait_ms": float(prompt_result.profile.get("bert_queue_wait_ms", 0.0)),
|
||||
"prompt_text_bert_batch_collect_wait_ms": float(prompt_result.profile.get("bert_batch_collect_wait_ms", 0.0)),
|
||||
"prompt_text_bert_forward_ms": float(prompt_result.profile.get("bert_forward_ms", 0.0)),
|
||||
"prompt_text_bert_tokenize_ms": float(prompt_result.profile.get("bert_tokenize_ms", 0.0)),
|
||||
"prompt_text_bert_scatter_ms": float(prompt_result.profile.get("bert_scatter_ms", 0.0)),
|
||||
"prompt_text_bert_calls": float(prompt_result.profile.get("bert_calls", 0.0)),
|
||||
"prompt_text_bert_stage_slots": float(prompt_result.profile.get("bert_stage_slots", 0.0)),
|
||||
"prompt_text_bert_stage_inflight_peak": float(prompt_result.profile.get("bert_stage_inflight_peak", 0.0)),
|
||||
"prompt_text_bert_batch_size_peak": float(prompt_result.profile.get("bert_batch_size_peak", 0.0)),
|
||||
"prompt_text_bert_batch_tokens_peak": float(prompt_result.profile.get("bert_batch_tokens_peak", 0.0)),
|
||||
"prompt_text_bert_pending_depth_on_enqueue_peak": float(
|
||||
prompt_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0)
|
||||
),
|
||||
"prompt_text_bert_pending_depth_on_collect_peak": float(
|
||||
prompt_result.profile.get("bert_pending_depth_on_collect_peak", 0.0)
|
||||
),
|
||||
"prompt_text_bert_high_pressure_mode_peak": float(
|
||||
prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0)
|
||||
),
|
||||
"prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)),
|
||||
"prompt_text_parallel_future_wait_ms": 0.0,
|
||||
"prompt_text_parallel_future_executor_queue_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_ms": float(prompt_result.total_ms),
|
||||
"prompt_text_parallel_future_finish_after_submit_ms": float(prompt_result.total_ms),
|
||||
"prompt_text_parallel_future_queue_tail_after_target_ms": 0.0,
|
||||
"prompt_text_parallel_future_run_tail_after_target_ms": 0.0,
|
||||
"text_bert_wait_ms": float(target_result.profile.get("bert_wait_ms", 0.0)),
|
||||
"text_bert_admission_wait_ms": float(target_result.profile.get("bert_admission_wait_ms", 0.0)),
|
||||
"text_bert_queue_wait_ms": float(target_result.profile.get("bert_queue_wait_ms", 0.0)),
|
||||
"text_bert_batch_collect_wait_ms": float(target_result.profile.get("bert_batch_collect_wait_ms", 0.0)),
|
||||
"text_bert_forward_ms": float(target_result.profile.get("bert_forward_ms", 0.0)),
|
||||
"text_bert_tokenize_ms": float(target_result.profile.get("bert_tokenize_ms", 0.0)),
|
||||
"text_bert_scatter_ms": float(target_result.profile.get("bert_scatter_ms", 0.0)),
|
||||
"text_bert_calls": float(target_result.profile.get("bert_calls", 0.0)),
|
||||
"text_bert_stage_slots": float(target_result.profile.get("bert_stage_slots", 0.0)),
|
||||
"text_bert_stage_inflight_peak": float(target_result.profile.get("bert_stage_inflight_peak", 0.0)),
|
||||
"text_bert_batch_size_peak": float(target_result.profile.get("bert_batch_size_peak", 0.0)),
|
||||
"text_bert_batch_tokens_peak": float(target_result.profile.get("bert_batch_tokens_peak", 0.0)),
|
||||
"text_bert_pending_depth_on_enqueue_peak": float(
|
||||
target_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0)
|
||||
),
|
||||
"text_bert_pending_depth_on_collect_peak": float(
|
||||
target_result.profile.get("bert_pending_depth_on_collect_peak", 0.0)
|
||||
),
|
||||
"text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)),
|
||||
"text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)),
|
||||
"text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)),
|
||||
"text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
|
||||
"audio_load_ms": audio_load_ms,
|
||||
"audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)),
|
||||
@ -233,6 +263,8 @@ def prepare_request_state(
|
||||
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
|
||||
"wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0,
|
||||
}
|
||||
if profile_overrides:
|
||||
prepare_profile.update({key: float(value) for key, value in profile_overrides.items()})
|
||||
return T2SRequestState(
|
||||
request_id=spec.request_id,
|
||||
ref_audio_path=spec.ref_audio_path,
|
||||
@ -240,8 +272,8 @@ def prepare_request_state(
|
||||
prompt_lang=spec.prompt_lang,
|
||||
text=text,
|
||||
text_lang=spec.text_lang,
|
||||
norm_prompt_text=prompt_norm_text,
|
||||
norm_text=norm_text,
|
||||
norm_prompt_text=prompt_result.norm_text,
|
||||
norm_text=target_result.norm_text,
|
||||
phones=phones_tensor,
|
||||
prompt_phones=prompt_phones_tensor,
|
||||
all_phones=all_phones,
|
||||
@ -260,6 +292,33 @@ def prepare_request_state(
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_request_state(
|
||||
tts: Any,
|
||||
spec: SchedulerRequestSpec,
|
||||
) -> T2SRequestState:
|
||||
prepare_start = time.perf_counter()
|
||||
prepare_sync_start = time.perf_counter()
|
||||
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
|
||||
text = spec.text.strip("\n")
|
||||
prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang)
|
||||
target_result = prepare_text_features(tts, text, spec.text_lang)
|
||||
if target_result.phones is None:
|
||||
raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
|
||||
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
|
||||
return build_request_state_from_parts(
|
||||
tts=tts,
|
||||
spec=spec,
|
||||
prompt_text=prompt_text,
|
||||
text=text,
|
||||
prompt_result=prompt_result,
|
||||
target_result=target_result,
|
||||
ref_audio_bundle=ref_audio_bundle,
|
||||
prepare_start=prepare_start,
|
||||
prepare_sync_start=prepare_sync_start,
|
||||
)
|
||||
|
||||
|
||||
def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor:
|
||||
if hidden.shape[0] >= target_len:
|
||||
return hidden
|
||||
@ -417,7 +476,8 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct
|
||||
decode_attn_mask=None,
|
||||
k_cache=None,
|
||||
v_cache=None,
|
||||
step_idx=0,
|
||||
kv_lens=None,
|
||||
step_indices=torch.zeros((len(states),), dtype=torch.long, device=device),
|
||||
prefill_done=False,
|
||||
)
|
||||
|
||||
@ -433,6 +493,64 @@ def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> to
|
||||
)
|
||||
|
||||
|
||||
def _compact_cache_to_kv_lens(
|
||||
cache: torch.Tensor,
|
||||
kv_lens: torch.LongTensor,
|
||||
) -> torch.Tensor:
|
||||
target_len = int(kv_lens.max().item())
|
||||
if cache.shape[1] == target_len and torch.all(kv_lens == target_len).item():
|
||||
return cache
|
||||
compacted = cache.new_zeros((cache.shape[0], target_len, cache.shape[2]))
|
||||
for batch_index, kv_len in enumerate(kv_lens.tolist()):
|
||||
if kv_len <= 0:
|
||||
continue
|
||||
compacted[batch_index, -kv_len:, :] = cache[batch_index, -kv_len:, :]
|
||||
return compacted
|
||||
|
||||
|
||||
def _compact_decode_mask_to_kv_lens(
|
||||
decode_attn_mask: Optional[torch.Tensor],
|
||||
kv_lens: torch.LongTensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
target_len = int(kv_lens.max().item()) + 1
|
||||
if decode_attn_mask is None:
|
||||
return None
|
||||
if decode_attn_mask.shape[-1] == target_len and torch.all(kv_lens + 1 == target_len).item():
|
||||
return decode_attn_mask
|
||||
compacted = torch.ones(
|
||||
(decode_attn_mask.shape[0], 1, 1, target_len),
|
||||
dtype=decode_attn_mask.dtype,
|
||||
device=decode_attn_mask.device,
|
||||
)
|
||||
for batch_index, kv_len in enumerate(kv_lens.tolist()):
|
||||
current_len = kv_len + 1
|
||||
compacted[batch_index, :, :, -current_len:] = decode_attn_mask[batch_index, :, :, -current_len:]
|
||||
if not compacted.any().item():
|
||||
return None
|
||||
return compacted
|
||||
|
||||
|
||||
def _advance_decode_mask(
|
||||
decode_attn_mask: Optional[torch.Tensor],
|
||||
kv_lens: torch.LongTensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if decode_attn_mask is None:
|
||||
return None
|
||||
target_len = int(kv_lens.max().item()) + 2
|
||||
advanced = torch.zeros(
|
||||
(decode_attn_mask.shape[0], 1, 1, target_len),
|
||||
dtype=decode_attn_mask.dtype,
|
||||
device=decode_attn_mask.device,
|
||||
)
|
||||
for batch_index, kv_len in enumerate(kv_lens.tolist()):
|
||||
current_len = kv_len + 1
|
||||
next_mask = F.pad(decode_attn_mask[batch_index : batch_index + 1, :, :, -current_len:], (0, 1), value=False)
|
||||
advanced[batch_index : batch_index + 1, :, :, -next_mask.shape[-1] :] = next_mask
|
||||
if not advanced.any().item():
|
||||
return None
|
||||
return advanced
|
||||
|
||||
|
||||
def _sample_per_request(
|
||||
model: Any,
|
||||
active_batch: T2SActiveBatch,
|
||||
@ -443,16 +561,15 @@ def _sample_per_request(
|
||||
keep_indices: List[int] = []
|
||||
updated_sequences: List[torch.LongTensor] = []
|
||||
|
||||
step_idx = active_batch.step_idx
|
||||
sampling_keys = [
|
||||
_sampling_group_key(
|
||||
top_k=state.top_k,
|
||||
top_p=state.top_p,
|
||||
temperature=state.temperature,
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
trim_eos=False,
|
||||
trim_eos=int(active_batch.step_indices[batch_index].item()) < 11,
|
||||
)
|
||||
for state in active_batch.states
|
||||
for batch_index, state in enumerate(active_batch.states)
|
||||
]
|
||||
sampled_items, argmax_tokens = _batched_sample_by_group(
|
||||
logits=logits,
|
||||
@ -460,6 +577,7 @@ def _sample_per_request(
|
||||
sampling_keys=sampling_keys,
|
||||
)
|
||||
for batch_index, state in enumerate(active_batch.states):
|
||||
step_index = int(active_batch.step_indices[batch_index].item())
|
||||
current_history = active_batch.y_sequences[batch_index]
|
||||
sampled = sampled_items[batch_index]
|
||||
sampled_token = int(sampled[0, 0].item())
|
||||
@ -469,7 +587,7 @@ def _sample_per_request(
|
||||
finish_reason: Optional[str] = None
|
||||
if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num:
|
||||
finish_reason = "early_stop"
|
||||
elif step_idx + 1 >= max_steps:
|
||||
elif step_index + 1 >= max_steps:
|
||||
finish_reason = "max_step"
|
||||
elif sampled_token == model.EOS:
|
||||
finish_reason = "eos_sample"
|
||||
@ -482,7 +600,7 @@ def _sample_per_request(
|
||||
T2SFinishedItem(
|
||||
request_id=state.request_id,
|
||||
semantic_tokens=new_history[prefix_len:-1].clone(),
|
||||
finish_idx=step_idx,
|
||||
finish_idx=step_index,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
)
|
||||
@ -493,30 +611,48 @@ def _sample_per_request(
|
||||
return finished_items, keep_indices, updated_sequences
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_one_step(
|
||||
model: Any,
|
||||
active_batch: T2SActiveBatch,
|
||||
max_steps: int,
|
||||
) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]:
|
||||
if not active_batch.prefill_done:
|
||||
was_prefill = not active_batch.prefill_done
|
||||
if was_prefill:
|
||||
if active_batch.prefill_attn_mask is None or active_batch.key_padding_mask is None:
|
||||
raise ValueError("prefill 阶段缺少必要 mask")
|
||||
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt(
|
||||
active_batch.xy_pos, active_batch.prefill_attn_mask, None
|
||||
)
|
||||
active_batch.kv_lens = active_batch.x_lens + active_batch.prefix_lens
|
||||
active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
|
||||
if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None:
|
||||
raise ValueError("prefill 阶段未生成完整 KV cache")
|
||||
active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache]
|
||||
active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache]
|
||||
active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(active_batch.decode_attn_mask, active_batch.kv_lens)
|
||||
active_batch.x = None
|
||||
active_batch.x_lens = None
|
||||
active_batch.key_padding_mask = None
|
||||
active_batch.prefill_attn_mask = None
|
||||
active_batch.prefill_done = True
|
||||
else:
|
||||
if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None:
|
||||
raise ValueError("decode 阶段缺少 KV cache")
|
||||
batched_decode_attn_mask = None
|
||||
if active_batch.decode_attn_mask is not None:
|
||||
batched_decode_attn_mask = _materialize_decode_mask_for_active_batch(active_batch)
|
||||
if not batched_decode_attn_mask.any().item():
|
||||
batched_decode_attn_mask = None
|
||||
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token(
|
||||
active_batch.xy_pos,
|
||||
active_batch.k_cache,
|
||||
active_batch.v_cache,
|
||||
active_batch.decode_attn_mask,
|
||||
batched_decode_attn_mask,
|
||||
)
|
||||
if active_batch.decode_attn_mask is not None:
|
||||
active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False)
|
||||
active_batch.decode_attn_mask = _advance_decode_mask(active_batch.decode_attn_mask, active_batch.kv_lens)
|
||||
|
||||
logits = model.ar_predict_layer(xy_dec[:, -1])
|
||||
if active_batch.step_idx < 11:
|
||||
logits = logits[:, :-1]
|
||||
|
||||
finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps)
|
||||
if len(keep_indices) == 0:
|
||||
@ -528,16 +664,32 @@ def decode_one_step(
|
||||
active_batch.states = [active_batch.states[i] for i in keep_indices]
|
||||
active_batch.y_sequences = updated_sequences
|
||||
active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor)
|
||||
next_step_indices = torch.index_select(active_batch.step_indices, dim=0, index=keep_tensor)
|
||||
next_kv_lens = None if active_batch.kv_lens is None else torch.index_select(active_batch.kv_lens, dim=0, index=keep_tensor)
|
||||
active_batch.step_indices = next_step_indices + 1
|
||||
if not was_prefill:
|
||||
if next_kv_lens is not None:
|
||||
active_batch.kv_lens = next_kv_lens + 1
|
||||
else:
|
||||
active_batch.kv_lens = next_kv_lens
|
||||
|
||||
if active_batch.decode_attn_mask is not None:
|
||||
active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor)
|
||||
if not active_batch.decode_attn_mask.any().item():
|
||||
active_batch.decode_attn_mask = None
|
||||
if active_batch.k_cache is not None and active_batch.v_cache is not None:
|
||||
for cache_index in range(len(active_batch.k_cache)):
|
||||
active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor)
|
||||
active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor)
|
||||
if active_batch.kv_lens is not None:
|
||||
active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache]
|
||||
active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache]
|
||||
active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(
|
||||
active_batch.decode_attn_mask,
|
||||
active_batch.kv_lens,
|
||||
)
|
||||
|
||||
active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences)
|
||||
active_batch.step_idx += 1
|
||||
return active_batch, finished_items
|
||||
|
||||
|
||||
@ -583,6 +735,126 @@ def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) ->
|
||||
)
|
||||
|
||||
|
||||
def _materialize_decode_mask_for_active_batch(
|
||||
active_batch: T2SActiveBatch,
|
||||
target_mask_len: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
if active_batch.k_cache is None or active_batch.kv_lens is None:
|
||||
raise ValueError("active batch 缺少 KV cache 或 kv_lens")
|
||||
current_mask_len = active_batch.k_cache[0].shape[1] + 1
|
||||
if target_mask_len is None:
|
||||
target_mask_len = current_mask_len
|
||||
if active_batch.decode_attn_mask is None:
|
||||
mask = torch.zeros(
|
||||
(len(active_batch.request_ids), 1, 1, current_mask_len),
|
||||
dtype=torch.bool,
|
||||
device=active_batch.k_cache[0].device,
|
||||
)
|
||||
else:
|
||||
rows: List[torch.Tensor] = []
|
||||
for batch_index, kv_len in enumerate(active_batch.kv_lens.tolist()):
|
||||
row_len = kv_len + 1
|
||||
row_mask = _fit_decode_mask_length(
|
||||
active_batch.decode_attn_mask[batch_index : batch_index + 1],
|
||||
row_len,
|
||||
)
|
||||
rows.append(_pad_decode_mask_left(row_mask, target_mask_len))
|
||||
mask = torch.cat(rows, dim=0)
|
||||
if target_mask_len != current_mask_len and active_batch.decode_attn_mask is None:
|
||||
mask = _pad_decode_mask_left(mask, target_mask_len)
|
||||
return mask
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_active_batch(
|
||||
model: Any,
|
||||
states: Sequence[T2SRequestState],
|
||||
max_steps: int,
|
||||
) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]:
|
||||
if not states:
|
||||
return None, []
|
||||
active_batch = build_prefill_batch(model, states)
|
||||
return decode_one_step(model, active_batch, max_steps=max_steps)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def merge_active_batches(
|
||||
model: Any,
|
||||
left_batch: Optional[T2SActiveBatch],
|
||||
right_batch: Optional[T2SActiveBatch],
|
||||
) -> Optional[T2SActiveBatch]:
|
||||
if left_batch is None:
|
||||
return right_batch
|
||||
if right_batch is None:
|
||||
return left_batch
|
||||
if not left_batch.prefill_done or not right_batch.prefill_done:
|
||||
raise ValueError("只有 prefill 完成后的 active batch 才能 merge")
|
||||
if left_batch.k_cache is None or left_batch.v_cache is None or right_batch.k_cache is None or right_batch.v_cache is None:
|
||||
raise ValueError("merge active batch 时缺少 KV cache")
|
||||
|
||||
left_kv_len = int(left_batch.k_cache[0].shape[1])
|
||||
right_kv_len = int(right_batch.k_cache[0].shape[1])
|
||||
merged_kv_len = max(left_kv_len, right_kv_len)
|
||||
merged_mask_len = merged_kv_len + 1
|
||||
|
||||
merged_k_cache: List[torch.Tensor] = []
|
||||
merged_v_cache: List[torch.Tensor] = []
|
||||
for layer_index in range(len(left_batch.k_cache)):
|
||||
merged_k_cache.append(
|
||||
torch.cat(
|
||||
[
|
||||
_pad_cache_left(left_batch.k_cache[layer_index], merged_kv_len),
|
||||
_pad_cache_left(right_batch.k_cache[layer_index], merged_kv_len),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
merged_v_cache.append(
|
||||
torch.cat(
|
||||
[
|
||||
_pad_cache_left(left_batch.v_cache[layer_index], merged_kv_len),
|
||||
_pad_cache_left(right_batch.v_cache[layer_index], merged_kv_len),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
merged_decode_attn_mask = torch.cat(
|
||||
[
|
||||
_materialize_decode_mask_for_active_batch(left_batch, merged_mask_len),
|
||||
_materialize_decode_mask_for_active_batch(right_batch, merged_mask_len),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
merged_request_ids = list(left_batch.request_ids) + list(right_batch.request_ids)
|
||||
merged_states = list(left_batch.states) + list(right_batch.states)
|
||||
merged_y_sequences = list(left_batch.y_sequences) + list(right_batch.y_sequences)
|
||||
merged_prefix_lens = torch.cat([left_batch.prefix_lens, right_batch.prefix_lens], dim=0)
|
||||
if left_batch.kv_lens is None or right_batch.kv_lens is None:
|
||||
raise ValueError("merge active batch 时缺少 kv_lens")
|
||||
merged_kv_lens = torch.cat([left_batch.kv_lens, right_batch.kv_lens], dim=0)
|
||||
merged_decode_attn_mask = _compact_decode_mask_to_kv_lens(merged_decode_attn_mask, merged_kv_lens)
|
||||
merged_step_indices = torch.cat([left_batch.step_indices, right_batch.step_indices], dim=0)
|
||||
|
||||
return T2SActiveBatch(
|
||||
request_ids=merged_request_ids,
|
||||
states=merged_states,
|
||||
x=None,
|
||||
x_lens=None,
|
||||
y_sequences=merged_y_sequences,
|
||||
prefix_lens=merged_prefix_lens,
|
||||
xy_pos=build_next_xy_pos(model, merged_y_sequences),
|
||||
key_padding_mask=None,
|
||||
prefill_attn_mask=None,
|
||||
decode_attn_mask=merged_decode_attn_mask,
|
||||
k_cache=merged_k_cache,
|
||||
v_cache=merged_v_cache,
|
||||
kv_lens=merged_kv_lens,
|
||||
step_indices=merged_step_indices,
|
||||
prefill_done=True,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_step(
|
||||
model: Any,
|
||||
@ -804,29 +1076,24 @@ def run_scheduler_continuous(
|
||||
max_steps: int,
|
||||
) -> List[T2SFinishedItem]:
|
||||
pending = sorted(states, key=lambda item: (item.ready_step, item.request_id))
|
||||
running_requests: List[T2SRunningRequest] = []
|
||||
active_batch: Optional[T2SActiveBatch] = None
|
||||
finished: List[T2SFinishedItem] = []
|
||||
current_tick = 0
|
||||
|
||||
while pending or running_requests:
|
||||
while pending or active_batch is not None:
|
||||
admitted: List[T2SRequestState] = []
|
||||
while pending and pending[0].ready_step <= current_tick:
|
||||
admitted.append(pending.pop(0))
|
||||
|
||||
admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps)
|
||||
admitted_active_batch, admitted_finished = run_prefill_active_batch(model, admitted, max_steps=max_steps)
|
||||
finished.extend(admitted_finished)
|
||||
active_batch = merge_active_batches(model, active_batch, admitted_active_batch)
|
||||
|
||||
if running_requests:
|
||||
running_requests, step_finished = run_decode_step_for_running(
|
||||
model,
|
||||
running_requests,
|
||||
max_steps=max_steps,
|
||||
)
|
||||
if active_batch is not None:
|
||||
active_batch, step_finished = decode_one_step(model, active_batch, max_steps=max_steps)
|
||||
finished.extend(step_finished)
|
||||
|
||||
running_requests.extend(admitted_running)
|
||||
|
||||
if not running_requests and pending:
|
||||
if active_batch is None and pending:
|
||||
current_tick = max(current_tick + 1, pending[0].ready_step)
|
||||
continue
|
||||
|
||||
|
||||
100
GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py
Normal file
100
GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py
Normal file
@ -0,0 +1,100 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
from text.LangSegmenter import LangSegmenter
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import clean_text
|
||||
|
||||
|
||||
PreparedTextSegmentPayload = Dict[str, object]
|
||||
|
||||
|
||||
def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]:
|
||||
textlist: List[str] = []
|
||||
langlist: List[str] = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||
)
|
||||
if same_group:
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
return textlist, langlist
|
||||
|
||||
|
||||
def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]:
|
||||
normalized_language = language.replace("all_", "")
|
||||
phones, word2ph, norm_text = clean_text(text, normalized_language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return list(phones), None if word2ph is None else list(word2ph), str(norm_text)
|
||||
|
||||
|
||||
def preprocess_text_segments_payload(
|
||||
text: str,
|
||||
language: str,
|
||||
version: str,
|
||||
final: bool = False,
|
||||
) -> List[PreparedTextSegmentPayload]:
|
||||
text = re.sub(r" {2,}", " ", text)
|
||||
textlist, langlist = split_text_by_language(text, language)
|
||||
payloads: List[PreparedTextSegmentPayload] = []
|
||||
total_phones_len = 0
|
||||
for segment_text, segment_lang in zip(textlist, langlist):
|
||||
phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
|
||||
payloads.append(
|
||||
{
|
||||
"language": segment_lang.replace("all_", ""),
|
||||
"phones": phones,
|
||||
"word2ph": word2ph,
|
||||
"norm_text": norm_text,
|
||||
}
|
||||
)
|
||||
total_phones_len += len(phones)
|
||||
|
||||
if not final and total_phones_len < 6:
|
||||
return preprocess_text_segments_payload("." + text, language, version, final=True)
|
||||
|
||||
return payloads
|
||||
@ -2,6 +2,7 @@ import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module):
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_batched_request_local(
|
||||
self,
|
||||
codes: torch.Tensor,
|
||||
code_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
refer_list: List[torch.Tensor],
|
||||
noise_scale: float = 0.5,
|
||||
speed: float = 1,
|
||||
sv_emb: torch.Tensor | None = None,
|
||||
):
|
||||
batch_size = int(codes.size(1))
|
||||
if batch_size <= 0:
|
||||
raise ValueError("decode_batched_request_local 收到空 batch")
|
||||
if len(refer_list) != batch_size:
|
||||
raise ValueError("refer_list 数量与 batch size 不一致")
|
||||
|
||||
refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device)
|
||||
max_refer_len = int(refer_lengths.max().item())
|
||||
refer_batch = torch.zeros(
|
||||
(batch_size, int(refer_list[0].size(1)), max_refer_len),
|
||||
dtype=refer_list[0].dtype,
|
||||
device=codes.device,
|
||||
)
|
||||
for batch_index, refer in enumerate(refer_list):
|
||||
refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer_batch * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
if sv_emb is None:
|
||||
raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb")
|
||||
ge = ge + self.sv_emb(sv_emb).unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")
|
||||
y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2
|
||||
text_lengths = text_lengths.to(device=text.device, dtype=torch.long)
|
||||
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
|
||||
quantized,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||
speed,
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
audio = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
upsample_factor = 1
|
||||
for up_layer in self.dec.ups:
|
||||
stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride)
|
||||
upsample_factor *= int(stride)
|
||||
audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor)
|
||||
return audio, audio_lengths
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||
|
||||
565
api_v3.py
565
api_v3.py
@ -107,6 +107,7 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Union
|
||||
@ -115,6 +116,10 @@ now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
from runtime_preload import preload_text_runtime_deps
|
||||
|
||||
preload_text_runtime_deps()
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import wave
|
||||
@ -128,14 +133,15 @@ import uvicorn
|
||||
from io import BytesIO
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
|
||||
SchedulerRequestSpec,
|
||||
T2SActiveBatch,
|
||||
T2SFinishedItem,
|
||||
T2SRunningRequest,
|
||||
T2SRequestState,
|
||||
prepare_request_state,
|
||||
run_decode_step_for_running,
|
||||
run_prefill_step,
|
||||
merge_active_batches,
|
||||
decode_one_step,
|
||||
run_prefill_active_batch,
|
||||
run_scheduler_continuous,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||
@ -238,39 +244,71 @@ class SchedulerPendingJob:
|
||||
request_id: str
|
||||
state: T2SRequestState
|
||||
done_event: threading.Event
|
||||
done_loop: asyncio.AbstractEventLoop | None
|
||||
done_future: asyncio.Future | None
|
||||
enqueue_time: float
|
||||
speed_factor: float
|
||||
sample_steps: int
|
||||
media_type: str
|
||||
prepare_ms: float = 0.0
|
||||
prepare_wall_ms: float = 0.0
|
||||
prepare_profile_total_ms: float = 0.0
|
||||
first_schedule_time: float | None = None
|
||||
prefill_ms: float = 0.0
|
||||
merge_ms: float = 0.0
|
||||
decode_ms: float = 0.0
|
||||
finalize_wait_ms: float = 0.0
|
||||
synth_ms: float = 0.0
|
||||
pack_ms: float = 0.0
|
||||
decode_steps: int = 0
|
||||
result_ready_time: float | None = None
|
||||
result: dict | None = None
|
||||
sample_rate: int | None = None
|
||||
audio_data: np.ndarray | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerFinalizeTask:
|
||||
request_id: str
|
||||
item: T2SFinishedItem
|
||||
enqueued_time: float
|
||||
|
||||
|
||||
class SchedulerDebugWorker:
|
||||
def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5):
|
||||
self.tts = tts
|
||||
self.max_steps = max_steps
|
||||
self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0
|
||||
self.prepare_coordinator = PrepareCoordinator(tts)
|
||||
self.condition = threading.Condition()
|
||||
self.prepare_inflight = 0
|
||||
self.prepare_peak_inflight = 0
|
||||
self.finalize_condition = threading.Condition()
|
||||
self.finalize_pending_tasks: deque[SchedulerFinalizeTask] = deque()
|
||||
self.finalize_pending_peak = 0
|
||||
self.finalize_inflight = 0
|
||||
self.finalize_inflight_peak = 0
|
||||
self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
|
||||
self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
|
||||
self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
|
||||
self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
|
||||
self.pending_jobs: List[SchedulerPendingJob] = []
|
||||
self.running_requests: List[T2SRunningRequest] = []
|
||||
self.active_batch: T2SActiveBatch | None = None
|
||||
self.job_map: dict[str, SchedulerPendingJob] = {}
|
||||
self.total_finished = 0
|
||||
self.total_submitted = 0
|
||||
self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True)
|
||||
self.worker_thread.start()
|
||||
self.finalize_threads = [
|
||||
threading.Thread(
|
||||
target=self._run_finalize_loop,
|
||||
name=f"t2s-scheduler-finalize-{worker_index}",
|
||||
daemon=True,
|
||||
)
|
||||
for worker_index in range(self.finalize_workers)
|
||||
]
|
||||
for finalize_thread in self.finalize_threads:
|
||||
finalize_thread.start()
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
try:
|
||||
@ -283,20 +321,7 @@ class SchedulerDebugWorker:
|
||||
pass
|
||||
|
||||
def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState:
|
||||
with self.condition:
|
||||
self.prepare_inflight += 1
|
||||
prepare_inflight_on_enter = self.prepare_inflight
|
||||
if self.prepare_inflight > self.prepare_peak_inflight:
|
||||
self.prepare_peak_inflight = self.prepare_inflight
|
||||
prepare_peak_inflight = self.prepare_peak_inflight
|
||||
try:
|
||||
state = prepare_request_state(self.tts, spec)
|
||||
state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter)
|
||||
state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight)
|
||||
return state
|
||||
finally:
|
||||
with self.condition:
|
||||
self.prepare_inflight = max(0, self.prepare_inflight - 1)
|
||||
raise RuntimeError("prepare_state sync path has been replaced by PrepareCoordinator")
|
||||
|
||||
def submit(
|
||||
self,
|
||||
@ -304,27 +329,47 @@ class SchedulerDebugWorker:
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
prepare_ms: float,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None = None,
|
||||
done_future: asyncio.Future | None = None,
|
||||
) -> SchedulerPendingJob:
|
||||
job = SchedulerPendingJob(
|
||||
request_id=state.request_id,
|
||||
state=state,
|
||||
done_event=threading.Event(),
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
enqueue_time=time.perf_counter(),
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
prepare_ms=float(prepare_ms),
|
||||
prepare_wall_ms=float(prepare_wall_ms),
|
||||
prepare_profile_total_ms=float(prepare_profile_total_ms),
|
||||
)
|
||||
with self.condition:
|
||||
self.pending_jobs.append(job)
|
||||
self.job_map[job.request_id] = job
|
||||
self.total_submitted += 1
|
||||
self.condition.notify_all()
|
||||
with self.finalize_condition:
|
||||
self.finalize_condition.notify_all()
|
||||
return job
|
||||
|
||||
async def prepare_state_async(self, spec: SchedulerRequestSpec) -> T2SRequestState:
|
||||
state, _, _ = await self.prepare_coordinator.prepare_state_profiled_async(spec, time.perf_counter())
|
||||
return state
|
||||
|
||||
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
|
||||
return await asyncio.gather(*[self.prepare_state_async(spec) for spec in specs])
|
||||
|
||||
async def prepare_state_profiled_async(
|
||||
self,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
|
||||
|
||||
def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None:
|
||||
with self.condition:
|
||||
for job in jobs:
|
||||
@ -340,6 +385,14 @@ class SchedulerDebugWorker:
|
||||
if tracked_job is not None:
|
||||
tracked_job.prefill_ms += elapsed_ms
|
||||
|
||||
def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
elapsed_ms = elapsed_s * 1000.0
|
||||
with self.condition:
|
||||
for request_id in request_ids:
|
||||
job = self.job_map.get(request_id)
|
||||
if job is not None:
|
||||
job.merge_ms += elapsed_ms
|
||||
|
||||
def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
elapsed_ms = elapsed_s * 1000.0
|
||||
with self.condition:
|
||||
@ -349,16 +402,30 @@ class SchedulerDebugWorker:
|
||||
job.decode_ms += elapsed_ms
|
||||
job.decode_steps += 1
|
||||
|
||||
def _add_finalize_wait_ms(self, request_ids: List[str], elapsed_ms: float) -> None:
|
||||
with self.condition:
|
||||
for request_id in request_ids:
|
||||
job = self.job_map.get(request_id)
|
||||
if job is not None:
|
||||
job.finalize_wait_ms += elapsed_ms
|
||||
|
||||
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
|
||||
semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device)
|
||||
phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device)
|
||||
semantic_tokens = item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0).to(self.tts.configs.device)
|
||||
phones = job.state.phones.detach().clone().unsqueeze(0).to(self.tts.configs.device)
|
||||
prompt_semantic = job.state.prompt_semantic.detach().clone()
|
||||
prompt_phones = job.state.prompt_phones.detach().clone()
|
||||
refer_spec = (
|
||||
job.state.refer_spec[0].detach().clone(),
|
||||
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
|
||||
)
|
||||
raw_audio = job.state.raw_audio.detach().clone()
|
||||
audio_fragment = self.tts.synthesize_audio_request_local(
|
||||
semantic_tokens=semantic_tokens,
|
||||
phones=phones,
|
||||
prompt_semantic=job.state.prompt_semantic,
|
||||
prompt_phones=job.state.prompt_phones,
|
||||
refer_spec=job.state.refer_spec,
|
||||
raw_audio=job.state.raw_audio,
|
||||
prompt_semantic=prompt_semantic,
|
||||
prompt_phones=prompt_phones,
|
||||
refer_spec=refer_spec,
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=job.state.raw_sr,
|
||||
speed=float(job.speed_factor),
|
||||
sample_steps=int(job.sample_steps),
|
||||
@ -375,6 +442,11 @@ class SchedulerDebugWorker:
|
||||
)
|
||||
|
||||
def get_state(self) -> dict:
|
||||
with self.finalize_condition:
|
||||
finalize_pending = len(self.finalize_pending_tasks)
|
||||
finalize_pending_peak = self.finalize_pending_peak
|
||||
finalize_inflight = self.finalize_inflight
|
||||
finalize_inflight_peak = self.finalize_inflight_peak
|
||||
with self.condition:
|
||||
bert_stage = self.tts.prepare_bert_stage_limiter.snapshot()
|
||||
ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot()
|
||||
@ -388,12 +460,24 @@ class SchedulerDebugWorker:
|
||||
if self.tts.prepare_ref_semantic_batch_worker is None
|
||||
else self.tts.prepare_ref_semantic_batch_worker.snapshot()
|
||||
)
|
||||
prepare_coordinator_state = self.prepare_coordinator.snapshot()
|
||||
return {
|
||||
"pending_jobs": len(self.pending_jobs),
|
||||
"running_requests": len(self.running_requests),
|
||||
"prepare_inflight": self.prepare_inflight,
|
||||
"prepare_peak_inflight": self.prepare_peak_inflight,
|
||||
"running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids),
|
||||
"prepare_inflight": prepare_coordinator_state["inflight"],
|
||||
"prepare_peak_inflight": prepare_coordinator_state["peak_inflight"],
|
||||
"finalize_pending": finalize_pending,
|
||||
"finalize_pending_peak": finalize_pending_peak,
|
||||
"finalize_inflight": finalize_inflight,
|
||||
"finalize_inflight_peak": finalize_inflight_peak,
|
||||
"finalize_workers": self.finalize_workers,
|
||||
"finalize_mode": self.finalize_mode,
|
||||
"finalize_batch_max_items": self.finalize_batch_max_items,
|
||||
"finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0,
|
||||
"prepare_request_executor_workers": 0,
|
||||
"prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)),
|
||||
"prepare_text_feature_workers": int(prepare_coordinator_state["text_feature_workers"]),
|
||||
"prepare_ref_audio_workers": int(prepare_coordinator_state["ref_audio_workers"]),
|
||||
"prepare_bert_stage": bert_stage,
|
||||
"prepare_bert_batch_worker": bert_batch_worker,
|
||||
"prepare_ref_audio_stage": ref_audio_stage,
|
||||
@ -405,59 +489,217 @@ class SchedulerDebugWorker:
|
||||
"micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000),
|
||||
}
|
||||
|
||||
def _finalize_finished(self, items: List[T2SFinishedItem]) -> None:
|
||||
def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None:
|
||||
if not items:
|
||||
return
|
||||
jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
tasks: List[SchedulerFinalizeTask] = []
|
||||
enqueued_time = time.perf_counter()
|
||||
with self.condition:
|
||||
for item in items:
|
||||
job = self.job_map.get(item.request_id)
|
||||
if job is not None:
|
||||
jobs_to_finalize.append((job, item))
|
||||
tasks.append(
|
||||
SchedulerFinalizeTask(
|
||||
request_id=item.request_id,
|
||||
item=item,
|
||||
enqueued_time=enqueued_time,
|
||||
)
|
||||
)
|
||||
if not tasks:
|
||||
return
|
||||
with self.finalize_condition:
|
||||
self.finalize_pending_tasks.extend(tasks)
|
||||
if len(self.finalize_pending_tasks) > self.finalize_pending_peak:
|
||||
self.finalize_pending_peak = len(self.finalize_pending_tasks)
|
||||
self.finalize_condition.notify_all()
|
||||
|
||||
for job, item in jobs_to_finalize:
|
||||
@staticmethod
|
||||
def _finalize_batch_key(job: SchedulerPendingJob) -> tuple[float, int]:
|
||||
return (round(float(job.speed_factor), 6), int(job.sample_steps))
|
||||
|
||||
def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]:
|
||||
with self.finalize_condition:
|
||||
while not self.finalize_pending_tasks:
|
||||
self.finalize_condition.wait()
|
||||
if self.finalize_mode == "after_t2s_drain":
|
||||
while not self._is_t2s_drained():
|
||||
self.finalize_condition.wait(timeout=self.micro_batch_wait_s)
|
||||
task = self.finalize_pending_tasks.popleft()
|
||||
selected_tasks = [task]
|
||||
batch_key = None
|
||||
with self.condition:
|
||||
first_job = self.job_map.get(task.request_id)
|
||||
if first_job is not None:
|
||||
batch_key = self._finalize_batch_key(first_job)
|
||||
batch_deadline = time.perf_counter() + self.finalize_batch_wait_s
|
||||
while len(selected_tasks) < self.finalize_batch_max_items:
|
||||
if batch_key is None:
|
||||
break
|
||||
matched_index = None
|
||||
for pending_index, pending_task in enumerate(self.finalize_pending_tasks):
|
||||
with self.condition:
|
||||
pending_job = self.job_map.get(pending_task.request_id)
|
||||
if pending_job is None:
|
||||
matched_index = pending_index
|
||||
break
|
||||
if self._finalize_batch_key(pending_job) == batch_key:
|
||||
matched_index = pending_index
|
||||
break
|
||||
if matched_index is not None:
|
||||
selected_tasks.append(self.finalize_pending_tasks[matched_index])
|
||||
del self.finalize_pending_tasks[matched_index]
|
||||
continue
|
||||
remaining = batch_deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
self.finalize_condition.wait(timeout=remaining)
|
||||
self.finalize_inflight += len(selected_tasks)
|
||||
if self.finalize_inflight > self.finalize_inflight_peak:
|
||||
self.finalize_inflight_peak = self.finalize_inflight
|
||||
return selected_tasks
|
||||
|
||||
def _finalize_task_done(self, count: int) -> None:
|
||||
with self.finalize_condition:
|
||||
self.finalize_inflight = max(0, self.finalize_inflight - count)
|
||||
|
||||
def _is_t2s_drained(self) -> bool:
|
||||
with self.condition:
|
||||
return (
|
||||
self.active_batch is None
|
||||
and not self.pending_jobs
|
||||
and self.prepare_inflight <= 0
|
||||
)
|
||||
|
||||
def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None:
|
||||
finished_at = time.perf_counter()
|
||||
with self.condition:
|
||||
if self.job_map.get(item.request_id) is not job:
|
||||
return
|
||||
queue_wait_ms = 0.0
|
||||
if job.first_schedule_time is not None:
|
||||
queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
|
||||
worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
|
||||
worker_residual_ms = max(
|
||||
0.0,
|
||||
worker_total_ms
|
||||
- queue_wait_ms
|
||||
- job.prefill_ms
|
||||
- job.merge_ms
|
||||
- job.decode_ms
|
||||
- job.finalize_wait_ms
|
||||
- job.synth_ms,
|
||||
)
|
||||
worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
|
||||
job.sample_rate = int(sample_rate)
|
||||
job.audio_data = audio_data
|
||||
job.result_ready_time = finished_at
|
||||
prepare_profile = dict(job.state.prepare_profile)
|
||||
job.result = {
|
||||
"request_id": item.request_id,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
"prepare_ms": job.prepare_wall_ms,
|
||||
"prepare_wall_ms": job.prepare_wall_ms,
|
||||
"prepare_profile_total_ms": job.prepare_profile_total_ms,
|
||||
"prepare_profile": prepare_profile,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
"prefill_ms": job.prefill_ms,
|
||||
"merge_ms": job.merge_ms,
|
||||
"decode_ms": job.decode_ms,
|
||||
"finalize_wait_ms": job.finalize_wait_ms,
|
||||
"synth_ms": job.synth_ms,
|
||||
"worker_residual_ms": worker_residual_ms,
|
||||
"worker_other_ms": worker_other_ms,
|
||||
"worker_total_ms": worker_total_ms,
|
||||
"decode_steps": int(job.decode_steps),
|
||||
"sample_rate": int(sample_rate),
|
||||
"media_type": job.media_type,
|
||||
}
|
||||
job.done_event.set()
|
||||
self._notify_done_future(job)
|
||||
self.job_map.pop(item.request_id, None)
|
||||
self.total_finished += 1
|
||||
|
||||
def _synthesize_finished_audio_batch(
|
||||
self,
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
|
||||
) -> List[tuple[int, np.ndarray]]:
|
||||
semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items]
|
||||
phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items]
|
||||
refer_specs = []
|
||||
speeds = []
|
||||
sample_steps_list = []
|
||||
for job, _ in jobs_and_items:
|
||||
refer_specs.append(
|
||||
(
|
||||
job.state.refer_spec[0].detach().clone(),
|
||||
None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(),
|
||||
)
|
||||
)
|
||||
speeds.append(float(job.speed_factor))
|
||||
sample_steps_list.append(int(job.sample_steps))
|
||||
audio_fragments = self.tts.synthesize_audio_requests_local_batched(
|
||||
semantic_tokens_list=semantic_tokens_list,
|
||||
phones_list=phones_list,
|
||||
refer_specs=refer_specs,
|
||||
speeds=speeds,
|
||||
sample_steps_list=sample_steps_list,
|
||||
)
|
||||
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
|
||||
results: List[tuple[int, np.ndarray]] = []
|
||||
for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments):
|
||||
results.append(
|
||||
self.tts.audio_postprocess(
|
||||
audio=[[audio_fragment]],
|
||||
sr=int(output_sr),
|
||||
batch_index_list=None,
|
||||
speed_factor=float(job.speed_factor),
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=False,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def _run_finalize_loop(self) -> None:
|
||||
while True:
|
||||
tasks = self._take_finalize_task_batch()
|
||||
try:
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
finalize_wait_request_ids: List[str] = []
|
||||
with self.condition:
|
||||
for task in tasks:
|
||||
job = self.job_map.get(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
jobs_and_items.append((job, task.item))
|
||||
finalize_wait_request_ids.append(task.request_id)
|
||||
if not jobs_and_items:
|
||||
continue
|
||||
now = time.perf_counter()
|
||||
for task in tasks:
|
||||
self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0))
|
||||
self._sync_device()
|
||||
synth_start = time.perf_counter()
|
||||
sample_rate, audio_data = self._synthesize_finished_audio(job, item)
|
||||
if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder:
|
||||
job, item = jobs_and_items[0]
|
||||
batch_results = [self._synthesize_finished_audio(job, item)]
|
||||
else:
|
||||
batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
|
||||
self._sync_device()
|
||||
synth_ms = (time.perf_counter() - synth_start) * 1000.0
|
||||
with self.condition:
|
||||
for job, _ in jobs_and_items:
|
||||
tracked_job = self.job_map.get(job.request_id)
|
||||
if tracked_job is not None:
|
||||
tracked_job.synth_ms += synth_ms
|
||||
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
|
||||
self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
except Exception as exc:
|
||||
self._finalize_error([item.request_id], str(exc))
|
||||
continue
|
||||
|
||||
finished_at = time.perf_counter()
|
||||
with self.condition:
|
||||
if self.job_map.get(item.request_id) is not job:
|
||||
continue
|
||||
queue_wait_ms = 0.0
|
||||
if job.first_schedule_time is not None:
|
||||
queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
|
||||
worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
|
||||
job.synth_ms += synth_ms
|
||||
job.sample_rate = int(sample_rate)
|
||||
job.audio_data = audio_data
|
||||
prepare_profile = dict(job.state.prepare_profile)
|
||||
job.result = {
|
||||
"request_id": item.request_id,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
"prepare_ms": job.prepare_ms,
|
||||
"prepare_wall_ms": job.prepare_wall_ms,
|
||||
"prepare_profile": prepare_profile,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
"prefill_ms": job.prefill_ms,
|
||||
"decode_ms": job.decode_ms,
|
||||
"synth_ms": job.synth_ms,
|
||||
"worker_total_ms": worker_total_ms,
|
||||
"decode_steps": int(job.decode_steps),
|
||||
"sample_rate": int(sample_rate),
|
||||
"media_type": job.media_type,
|
||||
}
|
||||
job.done_event.set()
|
||||
self.job_map.pop(item.request_id, None)
|
||||
self.total_finished += 1
|
||||
self._finalize_error([task.request_id for task in tasks], str(exc))
|
||||
finally:
|
||||
self._finalize_task_done(len(tasks))
|
||||
|
||||
def _finalize_error(self, request_ids: List[str], error: str) -> None:
|
||||
if not request_ids:
|
||||
@ -469,12 +711,28 @@ class SchedulerDebugWorker:
|
||||
continue
|
||||
job.error = error
|
||||
job.done_event.set()
|
||||
self._notify_done_future(job)
|
||||
self.job_map.pop(request_id, None)
|
||||
self.total_finished += 1
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(job: SchedulerPendingJob) -> None:
|
||||
future = job.done_future
|
||||
if future is None or future.done():
|
||||
return
|
||||
future.set_result(True)
|
||||
|
||||
def _notify_done_future(self, job: SchedulerPendingJob) -> None:
|
||||
if job.done_loop is None or job.done_future is None:
|
||||
return
|
||||
try:
|
||||
job.done_loop.call_soon_threadsafe(self._resolve_done_future, job)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
with self.condition:
|
||||
if not self.pending_jobs and not self.running_requests:
|
||||
if not self.pending_jobs and self.active_batch is None:
|
||||
self.condition.wait(timeout=self.micro_batch_wait_s)
|
||||
elif wait_for_batch and self.pending_jobs:
|
||||
self.condition.wait(timeout=self.micro_batch_wait_s)
|
||||
@ -482,11 +740,13 @@ class SchedulerDebugWorker:
|
||||
return []
|
||||
pending = list(self.pending_jobs)
|
||||
self.pending_jobs.clear()
|
||||
with self.finalize_condition:
|
||||
self.finalize_condition.notify_all()
|
||||
return pending
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
wait_for_batch = len(self.running_requests) == 0
|
||||
wait_for_batch = self.active_batch is None
|
||||
pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch)
|
||||
|
||||
if pending_jobs:
|
||||
@ -494,37 +754,54 @@ class SchedulerDebugWorker:
|
||||
self._sync_device()
|
||||
prefill_start = time.perf_counter()
|
||||
self._mark_prefill_started(pending_jobs, prefill_start)
|
||||
admitted_running, admitted_finished = run_prefill_step(
|
||||
admitted_active_batch, admitted_finished = run_prefill_active_batch(
|
||||
self.tts.t2s_model.model,
|
||||
[job.state for job in pending_jobs],
|
||||
max_steps=self.max_steps,
|
||||
)
|
||||
self._sync_device()
|
||||
self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start)
|
||||
self._finalize_finished(admitted_finished)
|
||||
self.running_requests.extend(admitted_running)
|
||||
self._enqueue_finalize_finished(admitted_finished)
|
||||
merge_start = time.perf_counter()
|
||||
self.active_batch = merge_active_batches(
|
||||
self.tts.t2s_model.model,
|
||||
self.active_batch,
|
||||
admitted_active_batch,
|
||||
)
|
||||
self._add_merge_time(
|
||||
[] if self.active_batch is None else list(self.active_batch.request_ids),
|
||||
time.perf_counter() - merge_start,
|
||||
)
|
||||
with self.finalize_condition:
|
||||
self.finalize_condition.notify_all()
|
||||
except Exception as exc:
|
||||
self._finalize_error([job.request_id for job in pending_jobs], str(exc))
|
||||
|
||||
if self.running_requests:
|
||||
if self.active_batch is not None:
|
||||
try:
|
||||
active_request_ids = [item.state.request_id for item in self.running_requests]
|
||||
active_request_ids = [state.request_id for state in self.active_batch.states]
|
||||
self._sync_device()
|
||||
decode_start = time.perf_counter()
|
||||
self.running_requests, step_finished = run_decode_step_for_running(
|
||||
self.active_batch, step_finished = decode_one_step(
|
||||
self.tts.t2s_model.model,
|
||||
self.running_requests,
|
||||
self.active_batch,
|
||||
max_steps=self.max_steps,
|
||||
)
|
||||
self._sync_device()
|
||||
self._add_decode_time(active_request_ids, time.perf_counter() - decode_start)
|
||||
self._finalize_finished(step_finished)
|
||||
self._enqueue_finalize_finished(step_finished)
|
||||
with self.finalize_condition:
|
||||
self.finalize_condition.notify_all()
|
||||
except Exception as exc:
|
||||
self._finalize_error(active_request_ids, str(exc))
|
||||
self.running_requests = []
|
||||
self.active_batch = None
|
||||
with self.finalize_condition:
|
||||
self.finalize_condition.notify_all()
|
||||
continue
|
||||
|
||||
if not pending_jobs:
|
||||
with self.finalize_condition:
|
||||
self.finalize_condition.notify_all()
|
||||
time.sleep(self.micro_batch_wait_s)
|
||||
|
||||
|
||||
@ -788,10 +1065,6 @@ def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
]
|
||||
|
||||
|
||||
def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
|
||||
return [scheduler_debug_worker.prepare_state(spec) for spec in specs]
|
||||
|
||||
|
||||
def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec:
|
||||
payload = request.dict()
|
||||
request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}"
|
||||
@ -845,7 +1118,7 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
|
||||
try:
|
||||
set_scheduler_seed(request.seed)
|
||||
specs = build_scheduler_request_specs(request.requests)
|
||||
states = await asyncio.to_thread(prepare_scheduler_states_batch, specs)
|
||||
states = await scheduler_debug_worker.prepare_states_batch_async(specs)
|
||||
finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps))
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
@ -867,20 +1140,51 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
|
||||
async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
try:
|
||||
request_start = time.perf_counter()
|
||||
prepare_start = request_start
|
||||
spec = build_scheduler_submit_spec(request)
|
||||
prepare_start = time.perf_counter()
|
||||
state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec)
|
||||
prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0
|
||||
prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms))
|
||||
spec_ready_at = time.perf_counter()
|
||||
prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async(
|
||||
spec,
|
||||
spec_ready_at,
|
||||
)
|
||||
prepare_end = time.perf_counter()
|
||||
prepare_wall_ms = (prepare_end - prepare_start) * 1000.0
|
||||
prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms))
|
||||
prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms))
|
||||
prepare_executor_queue_ms = float(
|
||||
state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0))
|
||||
)
|
||||
prepare_executor_run_ms = float(
|
||||
state.prepare_profile.get(
|
||||
"executor_run_wall_ms",
|
||||
max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0),
|
||||
)
|
||||
)
|
||||
prepare_other_ms = max(
|
||||
0.0,
|
||||
prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms,
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
job = scheduler_debug_worker.submit(
|
||||
state,
|
||||
speed_factor=float(request.speed_factor),
|
||||
sample_steps=int(request.sample_steps),
|
||||
media_type=request.media_type,
|
||||
prepare_ms=prepare_ms,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
)
|
||||
timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec))
|
||||
api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0)
|
||||
timeout_ok = False
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec))
|
||||
timeout_ok = True
|
||||
except asyncio.TimeoutError:
|
||||
timeout_ok = False
|
||||
wait_return_at = time.perf_counter()
|
||||
if not timeout_ok:
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
@ -888,8 +1192,10 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
"message": "queued",
|
||||
"request_id": job.request_id,
|
||||
"timings": {
|
||||
"prepare_ms": prepare_ms,
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"api_after_prepare_ms": api_after_prepare_ms,
|
||||
"request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0),
|
||||
},
|
||||
"worker_state": scheduler_debug_worker.get_state(),
|
||||
@ -911,9 +1217,13 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
)
|
||||
pack_start = time.perf_counter()
|
||||
audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
|
||||
pack_ms = (time.perf_counter() - pack_start) * 1000.0
|
||||
pack_end = time.perf_counter()
|
||||
pack_ms = (pack_end - pack_start) * 1000.0
|
||||
job.pack_ms = pack_ms
|
||||
request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
|
||||
api_wait_result_ms = 0.0
|
||||
if job.result_ready_time is not None:
|
||||
api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
|
||||
worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0
|
||||
headers = {
|
||||
"X-Request-Id": job.request_id,
|
||||
"X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0",
|
||||
@ -921,16 +1231,32 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
"X-Queue-Wait-Ms": (
|
||||
f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000"
|
||||
),
|
||||
"X-Prepare-Ms": f"{prepare_ms:.3f}",
|
||||
"X-Prepare-Ms": f"{prepare_wall_ms:.3f}",
|
||||
"X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}",
|
||||
"X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}",
|
||||
"X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}",
|
||||
"X-Prepare-Admission-Wait-Ms": (
|
||||
f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}"
|
||||
if job.result is not None
|
||||
else "0.000"
|
||||
),
|
||||
"X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}",
|
||||
"X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}",
|
||||
"X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}",
|
||||
"X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}",
|
||||
"X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}",
|
||||
"X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000",
|
||||
"X-Pack-Ms": f"{pack_ms:.3f}",
|
||||
"X-Worker-Total-Ms": (
|
||||
f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000"
|
||||
),
|
||||
"X-Request-Total-Ms": f"{request_total_ms:.3f}",
|
||||
"X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}",
|
||||
"X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0",
|
||||
}
|
||||
if job.result is not None:
|
||||
@ -939,16 +1265,48 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
{
|
||||
"X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(
|
||||
int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(
|
||||
int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(
|
||||
int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(
|
||||
int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Prompt-Bert-High-Pressure-Peak": str(
|
||||
int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Target-Bert-High-Pressure-Peak": str(
|
||||
int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Prompt-Bert-Batch-Size-Peak": str(
|
||||
int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Target-Bert-Batch-Size-Peak": str(
|
||||
int(prepare_profile.get("text_bert_batch_size_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
|
||||
"X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}",
|
||||
@ -964,13 +1322,22 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
"X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Inflight-On-Enter": str(
|
||||
int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))
|
||||
),
|
||||
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
|
||||
}
|
||||
)
|
||||
response_ready_at = time.perf_counter()
|
||||
response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
|
||||
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
|
||||
request_other_ms = max(
|
||||
0.0,
|
||||
request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms,
|
||||
)
|
||||
headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}"
|
||||
headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}"
|
||||
headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}"
|
||||
return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user