Implement batch processing for BERT and reference semantic tasks in TTS. Introduce StageLimiter for managing concurrent processing and enhance the TTS class with new methods for handling audio and semantic extraction. Update profiling metrics for better performance tracking during inference.

This commit is contained in:
baicai-1145 2026-03-09 05:19:28 +08:00
parent d245eb169c
commit 845b181360
6 changed files with 1024 additions and 140 deletions

View File

@ -5,6 +5,7 @@ import random
import sys import sys
import time import time
import traceback import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy from copy import deepcopy
import torchaudio import torchaudio
@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from tools.audio_sr import AP_BWE from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
from TTS_infer_pack.text_segmentation_method import splits from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
PrepareRefSemanticBatchWorker,
prepare_prompt_semantic_wav16k,
)
from sv import SV from sv import SV
resample_transform_dict = {} resample_transform_dict = {}
@ -442,11 +448,56 @@ class TTS:
"upsample_rate": None, "upsample_rate": None,
"overlapped_len": None, "overlapped_len": None,
} }
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2")))
self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None
default_text_cpu_workers = 16
self.prepare_text_cpu_workers = max(
0,
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))),
)
self.prepare_text_cpu_executor = None
if self.prepare_text_cpu_workers > 0:
self.prepare_text_cpu_executor = ThreadPoolExecutor(
max_workers=self.prepare_text_cpu_workers,
thread_name_prefix="prepare-text-cpu",
)
self._init_models() self._init_models()
if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
self.prepare_bert_batch_worker = PrepareBertBatchWorker(
bert_model=self.bert_model,
tokenizer=self.bert_tokenizer,
device=self.configs.device,
stage_limiter=self.prepare_bert_stage_limiter,
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
)
if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
if ref_max_batch_samples is None:
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000")
self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker(
ssl_model=self.cnhuhbert_model,
vits_model=self.vits_model,
device=self.configs.device,
is_half=self.configs.is_half,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
stage_limiter=self.prepare_ref_audio_stage_limiter,
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")),
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")),
max_batch_samples=int(ref_max_batch_samples),
)
self.text_preprocessor: TextPreprocessor = TextPreprocessor( self.text_preprocessor: TextPreprocessor = TextPreprocessor(
self.bert_model, self.bert_tokenizer, self.configs.device self.bert_model,
self.bert_tokenizer,
self.configs.device,
bert_stage_limiter=self.prepare_bert_stage_limiter,
bert_batch_worker=self.prepare_bert_batch_worker,
) )
self.prompt_cache: dict = { self.prompt_cache: dict = {
@ -755,47 +806,52 @@ class TTS:
Args: Args:
ref_audio_path: str, the path of the reference audio. ref_audio_path: str, the path of the reference audio.
""" """
self._set_prompt_semantic(ref_audio_path) bundle = self.extract_ref_audio_bundle(ref_audio_path)
self._set_ref_spec(ref_audio_path) if self.prompt_cache["refer_spec"] in [[], None]:
self.prompt_cache["refer_spec"] = [bundle["refer_spec"]]
else:
self.prompt_cache["refer_spec"][0] = bundle["refer_spec"]
self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"]
self.prompt_cache["raw_audio"] = bundle["raw_audio"]
self.prompt_cache["raw_sr"] = bundle["raw_sr"]
self._set_ref_audio_path(ref_audio_path) self._set_ref_audio_path(ref_audio_path)
def extract_prompt_semantic(self, ref_wav_path: str): def _load_ref_audio_raw(self, ref_audio_path: str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=np.float16 if self.configs.is_half else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
wav16k = wav16k.to(self.configs.device)
zero_wav_torch = zero_wav_torch.to(self.configs.device)
if self.configs.is_half:
wav16k = wav16k.half()
zero_wav_torch = zero_wav_torch.half()
wav16k = torch.cat([wav16k, zero_wav_torch])
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(
1, 2
) # .float()
codes = self.vits_model.extract_latent(hubert_feature)
prompt_semantic = codes[0, 0].to(self.configs.device)
return prompt_semantic
def extract_ref_spec(self, ref_audio_path: str):
raw_audio, raw_sr = torchaudio.load(ref_audio_path) raw_audio, raw_sr = torchaudio.load(ref_audio_path)
raw_audio = raw_audio.to(self.configs.device).float() return raw_audio.float(), int(raw_sr)
@torch.inference_mode()
def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor):
wav16k = wav16k.to(self.configs.device)
if self.configs.is_half:
wav16k = wav16k.half()
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
codes = self.vits_model.extract_latent(hubert_feature)
return codes[0, 0].to(self.configs.device)
@torch.inference_mode()
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
wav16k = prepare_prompt_semantic_wav16k(
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
)
return self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
def extract_prompt_semantic(self, ref_wav_path: str):
raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path)
return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
raw_audio_device = raw_audio.to(self.configs.device).float()
if raw_sr != self.configs.sampling_rate: if raw_sr != self.configs.sampling_rate:
audio = raw_audio.to(self.configs.device) audio = raw_audio_device
if audio.shape[0] == 2: if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0) audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
else: else:
audio = raw_audio.to(self.configs.device) audio = raw_audio_device
if audio.shape[0] == 2: if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0) audio = audio.mean(0).unsqueeze(0)
@ -820,8 +876,141 @@ class TTS:
audio = None audio = None
return spec, audio, raw_audio, raw_sr return spec, audio, raw_audio, raw_sr
def extract_text_features(self, text: str, language: str): def extract_ref_spec(self, ref_audio_path: str):
return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
return self._extract_ref_spec_from_raw(raw_audio, raw_sr)
def extract_ref_audio_bundle(self, ref_audio_path: str):
load_start = time.perf_counter()
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0
if self.prepare_ref_semantic_batch_worker is None:
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
ref_spec_start = time.perf_counter()
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
audio_stage_wait_ms = float(limiter_stats["wait_ms"])
audio_stage_slots = float(limiter_stats["slots"])
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": prompt_semantic_ms,
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"prompt_semantic_batch_size": 1.0,
"prompt_semantic_batch_samples": 0.0,
}
ref_spec_wait_ms = 0.0
return {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": load_ms,
"audio_stage_wait_ms": audio_stage_wait_ms,
"audio_stage_slots": audio_stage_slots,
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
"prompt_semantic_forward_ms": float(
prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)
),
"prompt_semantic_scatter_ms": float(
prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)
),
"prompt_semantic_stage_slots": float(
prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)
),
"prompt_semantic_stage_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
),
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
"prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
),
"ref_spec_wait_ms": ref_spec_wait_ms,
"ref_spec_ms": ref_spec_ms,
"bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
},
}
prompt_semantic_profile = {
"prompt_semantic_wait_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": 0.0,
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": 0.0,
"prompt_semantic_stage_inflight_peak": 0.0,
"prompt_semantic_batch_size": 1.0,
"prompt_semantic_batch_samples": 0.0,
}
if self.prepare_ref_semantic_batch_worker is not None:
prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
prompt_semantic_profile.update(worker_profile)
prompt_semantic_ms = (
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
)
with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
ref_spec_start = time.perf_counter()
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
ref_spec_limiter_stats["wait_ms"]
)
audio_stage_slots = max(
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
float(ref_spec_limiter_stats["slots"]),
)
audio_stage_inflight_peak = max(
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
float(ref_spec_limiter_stats["peak_inflight"]),
)
return {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": load_ms,
"audio_stage_wait_ms": audio_stage_wait_ms,
"audio_stage_slots": audio_stage_slots,
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
),
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
"prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
),
"ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
"ref_spec_ms": ref_spec_ms,
"bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
},
}
def extract_text_features(self, text: str, language: str, profile: dict | None = None):
return self.text_preprocessor.segment_and_extract_feature_for_text(
text, language, self.configs.version, profile=profile
)
def _set_ref_audio_path(self, ref_audio_path): def _set_ref_audio_path(self, ref_audio_path):
self.prompt_cache["ref_audio_path"] = ref_audio_path self.prompt_cache["ref_audio_path"] = ref_audio_path

View File

@ -1,6 +1,8 @@
import os import os
import sys import sys
import threading import threading
import time
from contextlib import contextmanager
from tqdm import tqdm from tqdm import tqdm
@ -16,6 +18,7 @@ from text.cleaner import clean_text
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:
return result return result
class StageLimiter:
def __init__(self, slots: int):
self.slots = max(1, int(slots))
self.semaphore = threading.BoundedSemaphore(self.slots)
self.lock = threading.Lock()
self.inflight = 0
self.peak_inflight = 0
@contextmanager
def enter(self):
wait_start = time.perf_counter()
self.semaphore.acquire()
wait_ms = (time.perf_counter() - wait_start) * 1000.0
with self.lock:
self.inflight += 1
current_inflight = self.inflight
if current_inflight > self.peak_inflight:
self.peak_inflight = current_inflight
peak_inflight = self.peak_inflight
try:
yield {
"wait_ms": wait_ms,
"inflight": current_inflight,
"peak_inflight": peak_inflight,
"slots": self.slots,
}
finally:
with self.lock:
self.inflight = max(0, self.inflight - 1)
self.semaphore.release()
def snapshot(self) -> Dict[str, int]:
with self.lock:
return {
"slots": self.slots,
"inflight": self.inflight,
"peak_inflight": self.peak_inflight,
}
class TextPreprocessor: class TextPreprocessor:
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): def __init__(
self,
bert_model: AutoModelForMaskedLM,
tokenizer: AutoTokenizer,
device: torch.device,
bert_stage_limiter: StageLimiter | None = None,
bert_batch_worker: PrepareBertBatchWorker | None = None,
):
self.bert_model = bert_model self.bert_model = bert_model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.device = device self.device = device
self.bert_lock = threading.RLock() self.bert_stage_limiter = bert_stage_limiter
self.bert_batch_worker = bert_batch_worker
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############") print(f"############ {i18n('切分文本')} ############")
@ -115,86 +166,136 @@ class TextPreprocessor:
return texts return texts
def segment_and_extract_feature_for_text( def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1" self, text: str, language: str, version: str = "v1", profile: Dict | None = None
) -> Tuple[list, torch.Tensor, str]: ) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version) return self.get_phones_and_bert(text, language, version, profile=profile)
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
with self.bert_lock: textlist = []
text = re.sub(r' {2,}', ' ', text) langlist = []
textlist = [] if language == "all_zh":
langlist = [] for tmp in LangSegmenter.getTexts(text, "zh"):
if language == "all_zh": langlist.append(tmp["lang"])
for tmp in LangSegmenter.getTexts(text,"zh"): textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text, "zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text, "ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text, "ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
)
if same_group:
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) else:
elif language == "all_yue": langlist.append(language)
for tmp in LangSegmenter.getTexts(text,"zh"): textlist.append(tmp["text"])
if tmp["lang"] == "zh": return textlist, langlist
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6: def get_phones_and_bert(
return self.get_phones_and_bert("." + text, language, version, final=True) self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
):
text = re.sub(r' {2,}', ' ', text)
textlist, langlist = self._split_text_by_language(text, language)
phones_list = []
bert_list = []
norm_text_list = []
for segment_text, segment_lang in zip(textlist, langlist):
phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
return phones, bert, norm_text if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile)
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: return phones, bert, norm_text
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt") def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
for i in inputs: if profile is None:
inputs[i] = inputs[i].to(self.device) return
res = self.bert_model(**inputs, output_hidden_states=True) profile[key] = float(profile.get(key, 0.0)) + float(value)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
if profile is None:
return
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
if self.bert_batch_worker is not None:
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
self._update_profile_peak(
profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)
)
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
if profile is not None:
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
return feature
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
if self.bert_stage_limiter is None:
forward_start = time.perf_counter()
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.bert_stage_limiter.enter() as limiter_stats:
forward_start = time.perf_counter()
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
forward_ms = (time.perf_counter() - forward_start) * 1000.0
self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
self._accumulate_profile(profile, "bert_calls", 1.0)
self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
if profile is not None:
profile["bert_stage_slots"] = float(limiter_stats["slots"])
assert len(word2ph) == len(text) assert len(word2ph) == len(text)
phone_level_feature = [] phone_level_feature = []
for i in range(len(word2ph)): for i in range(len(word2ph)):
@ -209,10 +310,10 @@ class TextPreprocessor:
phones = cleaned_text_to_sequence(phones, version) phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text return phones, word2ph, norm_text
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None):
language = language.replace("all_", "") language = language.replace("all_", "")
if language == "zh": if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device) feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
else: else:
feature = torch.zeros( feature = torch.zeros(
(1024, len(phones)), (1024, len(phones)),
@ -236,4 +337,4 @@ class TextPreprocessor:
punctuations = "".join(re.escape(p) for p in punctuation) punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+" pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text) result = re.sub(pattern, r"\1", text)
return result return result

View File

@ -0,0 +1,197 @@
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import torch
@dataclass
class BertFeatureTask:
norm_text: str
word2ph: List[int]
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
done_event: threading.Event = field(default_factory=threading.Event)
result_feature: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareBertBatchWorker:
def __init__(
self,
bert_model,
tokenizer,
device,
stage_limiter=None,
batch_window_ms: int = 5,
max_batch_items: int = 16,
max_batch_tokens: int = 4096,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.stage_limiter = stage_limiter
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_tokens = max(16, int(max_batch_tokens))
self.condition = threading.Condition()
self.pending_tasks: Deque[BertFeatureTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
self.worker_thread.start()
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
return max(1, len(task.norm_text) + 2)
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_feature is not None
return task.result_feature, dict(task.profile)
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
"pending": len(self.pending_tasks),
"pending_peak": self.pending_peak,
"total_submitted": self.total_submitted,
"total_finished": self.total_finished,
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_tokens": self.max_batch_tokens,
}
def _collect_batch(self) -> List[BertFeatureTask]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
batch_tokens = self._estimate_task_tokens(batch[0])
deadline = time.perf_counter() + self.batch_window_s
while len(batch) < self.max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
if not self.pending_tasks:
self.condition.wait(timeout=remaining)
continue
next_task = self.pending_tasks[0]
next_tokens = self._estimate_task_tokens(next_task)
if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens:
break
batch.append(self.pending_tasks.popleft())
batch_tokens += next_tokens
self.active_batch_size = len(batch)
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
return batch
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.total_batches += 1
self.total_finished += len(batch)
def _run_batch(self, batch: List[BertFeatureTask]) -> None:
batch_started = time.perf_counter()
texts = [task.norm_text for task in batch]
batch_tokens = sum(self._estimate_task_tokens(task) for task in batch)
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
if self.stage_limiter is None:
tokenize_start = time.perf_counter()
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
attention_mask_cpu = inputs["attention_mask"].cpu()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
forward_start = time.perf_counter()
with torch.no_grad():
outputs = self.bert_model(**inputs, output_hidden_states=True)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
tokenize_start = time.perf_counter()
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
attention_mask_cpu = inputs["attention_mask"].cpu()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
forward_start = time.perf_counter()
with torch.no_grad():
outputs = self.bert_model(**inputs, output_hidden_states=True)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
hidden = outputs["hidden_states"][-3].detach().cpu()
scatter_start = time.perf_counter()
for batch_index, task in enumerate(batch):
try:
text_len = len(task.word2ph)
if text_len != len(task.norm_text):
raise AssertionError(
f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}"
)
seq_len = int(attention_mask_cpu[batch_index].sum().item())
char_features = hidden[batch_index, 1 : seq_len - 1]
if char_features.shape[0] != text_len:
raise AssertionError(
f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}"
)
phone_level_feature = []
for char_index, repeat_count in enumerate(task.word2ph):
phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1))
task.result_feature = torch.cat(phone_level_feature, dim=0).T
task.profile = {
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"bert_forward_ms": float(forward_ms),
"bert_tokenize_ms": float(tokenize_ms),
"bert_scatter_ms": 0.0,
"bert_calls": 1.0,
"bert_stage_slots": float(limiter_stats["slots"]),
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"bert_batch_size": float(len(batch)),
"bert_batch_tokens": float(batch_tokens),
}
except Exception as exc: # noqa: PERF203
task.error = exc
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
for task in batch:
if task.result_feature is not None:
task.profile["bert_scatter_ms"] = float(scatter_ms)
task.done_event.set()
def _run_loop(self) -> None:
while True:
batch = self._collect_batch()
try:
self._run_batch(batch)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
finally:
self._finalize_batch(batch)

View File

@ -0,0 +1,262 @@
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import librosa
import numpy as np
import torch
REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
wav_mono = raw_audio
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
wav_mono = wav_mono.mean(0, keepdim=True)
wav16k = wav_mono.squeeze(0).cpu().numpy()
if raw_sr != 16000:
wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000)
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
raise OSError("参考音频在3~10秒范围外请更换")
wav16k = np.ascontiguousarray(wav16k, dtype=np.float32)
if zero_wav_samples > 0:
wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0)
return torch.from_numpy(wav16k)
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
if conv1d is None:
return input_lengths.to(dtype=torch.long)
kernel_size = int(conv1d.kernel_size[0])
stride = int(conv1d.stride[0])
padding = int(conv1d.padding[0])
dilation = int(conv1d.dilation[0])
output_lengths = torch.div(
input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
stride,
rounding_mode="floor",
) + 1
return torch.clamp(output_lengths, min=0).to(dtype=torch.long)
@dataclass
class RefSemanticTask:
raw_audio: torch.Tensor
raw_sr: int
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
done_event: threading.Event = field(default_factory=threading.Event)
result_prompt_semantic: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareRefSemanticBatchWorker:
def __init__(
self,
ssl_model,
vits_model,
device,
is_half: bool,
zero_wav_samples: int,
stage_limiter=None,
batch_window_ms: int = 5,
max_batch_items: int = 8,
max_batch_samples: int = 960000,
):
self.ssl_model = ssl_model
self.vits_model = vits_model
self.device = device
self.is_half = bool(is_half)
self.zero_wav_samples = max(0, int(zero_wav_samples))
self.stage_limiter = stage_limiter
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples))
self.condition = threading.Condition()
self.pending_tasks: Deque[RefSemanticTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.active_batch_samples = 0
self.active_batch_samples_peak = 0
self.worker_thread = threading.Thread(
target=self._run_loop,
name="prepare-ref-semantic-batch-worker",
daemon=True,
)
self.worker_thread.start()
def _estimate_task_samples(self, task: RefSemanticTask) -> int:
raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0
base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr))))
return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples
def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_prompt_semantic is not None
return task.result_prompt_semantic, dict(task.profile)
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
"pending": len(self.pending_tasks),
"pending_peak": self.pending_peak,
"total_submitted": self.total_submitted,
"total_finished": self.total_finished,
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"active_batch_samples": self.active_batch_samples,
"active_batch_samples_peak": self.active_batch_samples_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_samples": self.max_batch_samples,
}
def _collect_batch(self) -> List[RefSemanticTask]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
batch_samples = self._estimate_task_samples(batch[0])
deadline = time.perf_counter() + self.batch_window_s
while len(batch) < self.max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
if not self.pending_tasks:
self.condition.wait(timeout=remaining)
continue
next_task = self.pending_tasks[0]
next_samples = self._estimate_task_samples(next_task)
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
break
batch.append(self.pending_tasks.popleft())
batch_samples += next_samples
self.active_batch_size = len(batch)
self.active_batch_samples = batch_samples
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
if self.active_batch_samples > self.active_batch_samples_peak:
self.active_batch_samples_peak = self.active_batch_samples
return batch
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.active_batch_samples = 0
self.total_batches += 1
self.total_finished += len(batch)
def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor:
model = self.ssl_model.model
if hasattr(model, "_get_feature_vector_attention_mask"):
feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask)
return feature_mask.to(dtype=torch.long).sum(dim=1)
raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1)
if hasattr(model, "_get_feat_extract_output_lengths"):
return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long)
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
@torch.inference_mode()
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
batch_started = time.perf_counter()
prepared_start = time.perf_counter()
prepared_wavs = [
prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch
]
cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
batch_samples = int(wav_lengths.sum().item())
max_wav_len = int(wav_lengths.max().item())
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
for batch_index, wav in enumerate(prepared_wavs):
wav_len = int(wav.shape[0])
input_values_cpu[batch_index, :wav_len] = wav
attention_mask_cpu[batch_index, :wav_len] = 1
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
if self.stage_limiter is None:
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
forward_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
forward_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
scatter_start = time.perf_counter()
for batch_index, task in enumerate(batch):
try:
code_len = int(code_lengths[batch_index].item())
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
task.profile = {
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_calls": 1.0,
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"prompt_semantic_batch_size": float(len(batch)),
"prompt_semantic_batch_samples": float(batch_samples),
}
except Exception as exc: # noqa: PERF203
task.error = exc
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
for task in batch:
if task.result_prompt_semantic is not None:
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
task.done_event.set()
def _run_loop(self) -> None:
while True:
batch = self._collect_batch()
try:
self._run_batch(batch)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
finally:
self._finalize_batch(batch)

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import time import time
@ -123,31 +124,58 @@ def prepare_request_state(
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
text = spec.text.strip("\n") text = spec.text.strip("\n")
_sync_device(device) prompt_text_profile: Dict[str, float] = {}
prompt_text_features_start = time.perf_counter() text_features_profile: Dict[str, float] = {}
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) text_feature_pair_start = time.perf_counter()
_sync_device(device) prompt_future: Future | None = None
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
def _extract_prompt_features():
_sync_device(device)
prompt_start = time.perf_counter()
result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile)
_sync_device(device)
return result, (time.perf_counter() - prompt_start) * 1000.0
if getattr(tts, "prepare_text_cpu_executor", None) is not None:
prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features)
_sync_device(device) _sync_device(device)
text_features_start = time.perf_counter() text_features_start = time.perf_counter()
phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile)
_sync_device(device) _sync_device(device)
text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 text_features_ms = (time.perf_counter() - text_features_start) * 1000.0
if prompt_future is None:
_sync_device(device)
prompt_text_features_start = time.perf_counter()
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(
prompt_text, spec.prompt_lang, profile=prompt_text_profile
)
_sync_device(device)
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
prompt_text_profile["parallel_future_wait_ms"] = 0.0
else:
prompt_wait_start = time.perf_counter()
(prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result()
prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0
text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0
if phones is None: if phones is None:
raise ValueError(f"{spec.request_id} text preprocessing returned no phones") raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
_sync_device(device) _sync_device(device)
prompt_semantic_start = time.perf_counter() ref_audio_bundle_start = time.perf_counter()
prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
raw_audio = ref_audio_bundle["raw_audio"]
raw_sr = int(ref_audio_bundle["raw_sr"])
_sync_device(device) _sync_device(device)
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0
bundle_profile = ref_audio_bundle.get("profile", {})
_sync_device(device) prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
ref_spec_start = time.perf_counter() ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0))
spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0))
_sync_device(device)
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
_sync_device(device) _sync_device(device)
tensorize_start = time.perf_counter() tensorize_start = time.perf_counter()
@ -164,8 +192,43 @@ def prepare_request_state(
prepare_profile = { prepare_profile = {
"prompt_text_features_ms": prompt_text_features_ms, "prompt_text_features_ms": prompt_text_features_ms,
"text_features_ms": text_features_ms, "text_features_ms": text_features_ms,
"prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)),
"prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)),
"prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)),
"prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)),
"prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)),
"prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)),
"prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)),
"prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)),
"prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)),
"prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)),
"text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)),
"text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)),
"text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)),
"text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)),
"text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)),
"text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)),
"text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)),
"text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)),
"text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)),
"text_feature_pair_ms": text_feature_pair_ms,
"text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
"audio_load_ms": audio_load_ms,
"audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)),
"audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)),
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
"prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)),
"prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)),
"ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": ref_spec_ms, "ref_spec_ms": ref_spec_ms,
"ref_audio_bundle_ms": ref_audio_bundle_ms,
"tensorize_ms": tensorize_ms, "tensorize_ms": tensorize_ms,
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
"wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0,
@ -186,7 +249,7 @@ def prepare_request_state(
prompt_semantic=prompt_semantic, prompt_semantic=prompt_semantic,
refer_spec=(spec_audio, audio_16k), refer_spec=(spec_audio, audio_16k),
raw_audio=raw_audio, raw_audio=raw_audio,
raw_sr=int(raw_sr), raw_sr=raw_sr,
top_k=spec.top_k, top_k=spec.top_k,
top_p=spec.top_p, top_p=spec.top_p,
temperature=spec.temperature, temperature=spec.temperature,
@ -409,9 +472,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
return F.pad(mask, (pad_len, 0), value=True) return F.pad(mask, (pad_len, 0), value=True)
def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor:
if mask.shape[-1] > target_len:
return mask[:, :, :, -target_len:]
if mask.shape[-1] < target_len:
return _pad_decode_mask_left(mask, target_len)
return mask
def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor:
expected_mask_len = running_request.k_cache[0].shape[1] + 1
if running_request.decode_attn_mask is not None: if running_request.decode_attn_mask is not None:
return running_request.decode_attn_mask return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len)
current_mask_len = running_request.k_cache[0].shape[1] + 1 current_mask_len = running_request.k_cache[0].shape[1] + 1
return torch.zeros( return torch.zeros(
(1, 1, 1, current_mask_len), (1, 1, 1, current_mask_len),
@ -481,17 +553,19 @@ def run_prefill_step(
real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len
request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache]
request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache]
request_decode_attn_mask = None
if decode_attn_mask is not None:
request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone()
request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1)
if not request_decode_attn_mask.any().item():
request_decode_attn_mask = None
running_requests.append( running_requests.append(
T2SRunningRequest( T2SRunningRequest(
state=state, state=state,
y_sequence=new_history, y_sequence=new_history,
prefix_len=prefix_len, prefix_len=prefix_len,
decode_attn_mask=( decode_attn_mask=request_decode_attn_mask,
None
if decode_attn_mask is None
else decode_attn_mask[batch_index : batch_index + 1].clone()
),
k_cache=request_k_cache, k_cache=request_k_cache,
v_cache=request_v_cache, v_cache=request_v_cache,
step_idx=1, step_idx=1,
@ -603,6 +677,9 @@ def run_decode_step_for_running(
batch_index : batch_index + 1, :, :, -current_decode_mask_len: batch_index : batch_index + 1, :, :, -current_decode_mask_len:
] ]
next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False)
next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1)
if not next_decode_attn_mask.any().item():
next_decode_attn_mask = None
next_running.append( next_running.append(
T2SRunningRequest( T2SRunningRequest(
state=running_request.state, state=running_request.state,

View File

@ -261,8 +261,9 @@ class SchedulerDebugWorker:
self.tts = tts self.tts = tts
self.max_steps = max_steps self.max_steps = max_steps
self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0
self.prepare_lock = threading.Lock()
self.condition = threading.Condition() self.condition = threading.Condition()
self.prepare_inflight = 0
self.prepare_peak_inflight = 0
self.pending_jobs: List[SchedulerPendingJob] = [] self.pending_jobs: List[SchedulerPendingJob] = []
self.running_requests: List[T2SRunningRequest] = [] self.running_requests: List[T2SRunningRequest] = []
self.job_map: dict[str, SchedulerPendingJob] = {} self.job_map: dict[str, SchedulerPendingJob] = {}
@ -282,8 +283,20 @@ class SchedulerDebugWorker:
pass pass
def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState:
with self.prepare_lock: with self.condition:
return prepare_request_state(self.tts, spec) self.prepare_inflight += 1
prepare_inflight_on_enter = self.prepare_inflight
if self.prepare_inflight > self.prepare_peak_inflight:
self.prepare_peak_inflight = self.prepare_inflight
prepare_peak_inflight = self.prepare_peak_inflight
try:
state = prepare_request_state(self.tts, spec)
state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter)
state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight)
return state
finally:
with self.condition:
self.prepare_inflight = max(0, self.prepare_inflight - 1)
def submit( def submit(
self, self,
@ -363,9 +376,28 @@ class SchedulerDebugWorker:
def get_state(self) -> dict: def get_state(self) -> dict:
with self.condition: with self.condition:
bert_stage = self.tts.prepare_bert_stage_limiter.snapshot()
ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot()
bert_batch_worker = (
None
if self.tts.prepare_bert_batch_worker is None
else self.tts.prepare_bert_batch_worker.snapshot()
)
ref_semantic_batch_worker = (
None
if self.tts.prepare_ref_semantic_batch_worker is None
else self.tts.prepare_ref_semantic_batch_worker.snapshot()
)
return { return {
"pending_jobs": len(self.pending_jobs), "pending_jobs": len(self.pending_jobs),
"running_requests": len(self.running_requests), "running_requests": len(self.running_requests),
"prepare_inflight": self.prepare_inflight,
"prepare_peak_inflight": self.prepare_peak_inflight,
"prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)),
"prepare_bert_stage": bert_stage,
"prepare_bert_batch_worker": bert_batch_worker,
"prepare_ref_audio_stage": ref_audio_stage,
"prepare_ref_semantic_batch_worker": ref_semantic_batch_worker,
"tracked_jobs": len(self.job_map), "tracked_jobs": len(self.job_map),
"total_submitted": self.total_submitted, "total_submitted": self.total_submitted,
"total_finished": self.total_finished, "total_finished": self.total_finished,
@ -907,10 +939,36 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
{ {
"X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}",
"X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}",
"X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Bert-Batch-Size-Peak": str(
int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0))
),
"X-Prepare-Target-Bert-Batch-Size-Peak": str(
int(prepare_profile.get("text_bert_batch_size_peak", 0.0))
),
"X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}",
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
"X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}",
"X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}",
"X-Prepare-Prompt-Semantic-Batch-Size": str(
int(prepare_profile.get("prompt_semantic_batch_size", 0.0))
),
"X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}",
"X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}",
"X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}",
"X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}",
"X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}",
"X-Prepare-Inflight-On-Enter": str(
int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))
),
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
} }
) )
return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)