diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 2fd0df35..9c259662 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -5,6 +5,7 @@ import random import sys import time import traceback +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy import torchaudio @@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list from TTS_infer_pack.text_segmentation_method import splits -from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( + PrepareRefSemanticBatchWorker, + prepare_prompt_semantic_wav16k, +) from sv import SV resample_transform_dict = {} @@ -442,11 +448,56 @@ class TTS: "upsample_rate": None, "overlapped_len": None, } + self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None + default_text_cpu_workers = 16 + self.prepare_text_cpu_workers = max( + 0, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + ) + self.prepare_text_cpu_executor = None + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_executor = ThreadPoolExecutor( + max_workers=self.prepare_text_cpu_workers, + thread_name_prefix="prepare-text-cpu", + ) self._init_models() + if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": + self.prepare_bert_batch_worker = PrepareBertBatchWorker( + bert_model=self.bert_model, + tokenizer=self.bert_tokenizer, + 
device=self.configs.device, + stage_limiter=self.prepare_bert_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), + max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + ) + if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") + if ref_max_batch_samples is None: + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000") + self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker( + ssl_model=self.cnhuhbert_model, + vits_model=self.vits_model, + device=self.configs.device, + is_half=self.configs.is_half, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + stage_limiter=self.prepare_ref_audio_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")), + max_batch_samples=int(ref_max_batch_samples), + ) + self.text_preprocessor: TextPreprocessor = TextPreprocessor( - self.bert_model, self.bert_tokenizer, self.configs.device + self.bert_model, + self.bert_tokenizer, + self.configs.device, + bert_stage_limiter=self.prepare_bert_stage_limiter, + bert_batch_worker=self.prepare_bert_batch_worker, ) self.prompt_cache: dict = { @@ -755,47 +806,52 @@ class TTS: Args: ref_audio_path: str, the path of the reference audio. 
""" - self._set_prompt_semantic(ref_audio_path) - self._set_ref_spec(ref_audio_path) + bundle = self.extract_ref_audio_bundle(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [bundle["refer_spec"]] + else: + self.prompt_cache["refer_spec"][0] = bundle["refer_spec"] + self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"] + self.prompt_cache["raw_audio"] = bundle["raw_audio"] + self.prompt_cache["raw_sr"] = bundle["raw_sr"] self._set_ref_audio_path(ref_audio_path) - def extract_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - return prompt_semantic - - def extract_ref_spec(self, ref_audio_path: str): + def _load_ref_audio_raw(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) - raw_audio = raw_audio.to(self.configs.device).float() + return raw_audio.float(), int(raw_sr) + + @torch.inference_mode() + def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor): + wav16k = wav16k.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + hubert_feature = 
self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = self.vits_model.extract_latent(hubert_feature) + return codes[0, 0].to(self.configs.device) + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + + def extract_prompt_semantic(self, ref_wav_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) + return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + + def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + raw_audio_device = raw_audio.to(self.configs.device).float() if raw_sr != self.configs.sampling_rate: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) else: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) @@ -820,8 +876,141 @@ class TTS: audio = None return spec, audio, raw_audio, raw_sr - def extract_text_features(self, text: str, language: str): - return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + def extract_ref_spec(self, ref_audio_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + return self._extract_ref_spec_from_raw(raw_audio, raw_sr) + + def extract_ref_audio_bundle(self, ref_audio_path: str): + load_start = time.perf_counter() + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + if self.prepare_ref_semantic_batch_worker is None: + with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: + 
prompt_semantic_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(limiter_stats["wait_ms"]) + audio_stage_slots = float(limiter_stats["slots"]) + audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) + prompt_semantic_profile = { + "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + ref_spec_wait_ms = 0.0 + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0) + ), + "prompt_semantic_scatter_ms": float( + prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + "prompt_semantic_stage_slots": float( + prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0) + ), + "prompt_semantic_stage_inflight_peak": float( + 
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": ref_spec_wait_ms, + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + prompt_semantic_profile = { + "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": 0.0, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": 0.0, + "prompt_semantic_stage_inflight_peak": 0.0, + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + if self.prepare_ref_semantic_batch_worker is not None: + prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + prompt_semantic_profile.update(worker_profile) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_limiter_stats["wait_ms"] + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_limiter_stats["slots"]), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_limiter_stats["peak_inflight"]), + ) + return { + 
"prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + def extract_text_features(self, text: str, language: str, profile: dict | None = None): + return self.text_preprocessor.segment_and_extract_feature_for_text( + text, language, self.configs.version, profile=profile + ) def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 283e91c3..15b3c322 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,6 +1,8 @@ import os 
import sys import threading +import time +from contextlib import contextmanager from tqdm import tqdm @@ -16,6 +18,7 @@ from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker from tools.i18n.i18n import I18nAuto, scan_language_list @@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list: return result +class StageLimiter: + def __init__(self, slots: int): + self.slots = max(1, int(slots)) + self.semaphore = threading.BoundedSemaphore(self.slots) + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + + @contextmanager + def enter(self): + wait_start = time.perf_counter() + self.semaphore.acquire() + wait_ms = (time.perf_counter() - wait_start) * 1000.0 + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + peak_inflight = self.peak_inflight + try: + yield { + "wait_ms": wait_ms, + "inflight": current_inflight, + "peak_inflight": peak_inflight, + "slots": self.slots, + } + finally: + with self.lock: + self.inflight = max(0, self.inflight - 1) + self.semaphore.release() + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "slots": self.slots, + "inflight": self.inflight, + "peak_inflight": self.peak_inflight, + } + + class TextPreprocessor: - def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): + def __init__( + self, + bert_model: AutoModelForMaskedLM, + tokenizer: AutoTokenizer, + device: torch.device, + bert_stage_limiter: StageLimiter | None = None, + bert_batch_worker: PrepareBertBatchWorker | None = None, + ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = 
device - self.bert_lock = threading.RLock() + self.bert_stage_limiter = bert_stage_limiter + self.bert_batch_worker = bert_batch_worker def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") @@ -115,86 +166,136 @@ class TextPreprocessor: return texts def segment_and_extract_feature_for_text( - self, text: str, language: str, version: str = "v1" + self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version) + return self.get_phones_and_bert(text, language, version, profile=profile) - def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): - with self.bert_lock: - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): + def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": 
+ tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - # print(textlist) - # print(langlist) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - 
bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True) + def get_phones_and_bert( + self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None + ): + text = re.sub(r' {2,}', ' ', text) + textlist, langlist = self._split_text_by_language(text, language) + phones_list = [] + bert_list = [] + norm_text_list = [] + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) + bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) - return phones, bert, norm_text + if not final and len(phones) < 6: + return self.get_phones_and_bert("." 
+ text, language, version, final=True, profile=profile) - def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + return phones, bert, norm_text + + def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(profile.get(key, 0.0)) + float(value) + + def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: + if self.bert_batch_worker is not None: + feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak( + profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) + ) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + return feature + + limiter_stats = {"wait_ms": 0.0, 
"inflight": 1, "peak_inflight": 1, "slots": 0} + if self.bert_stage_limiter is None: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.bert_stage_limiter.enter() as limiter_stats: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"]) + self._accumulate_profile(profile, "bert_forward_ms", forward_ms) + self._accumulate_profile(profile, "bert_calls", 1.0) + self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"]) + if profile is not None: + profile["bert_stage_slots"] = float(limiter_stats["slots"]) assert len(word2ph) == len(text) phone_level_feature = [] for i in range(len(word2ph)): @@ -209,10 +310,10 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): + def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): language = language.replace("all_", "") if language == "zh": - feature = self.get_bert_feature(norm_text, word2ph).to(self.device) + feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( (1024, len(phones)), @@ -236,4 +337,4 @@ class TextPreprocessor: 
punctuations = "".join(re.escape(p) for p in punctuation) pattern = f"([{punctuations}])([{punctuations}])+" result = re.sub(pattern, r"\1", text) - return result \ No newline at end of file + return result diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py new file mode 100644 index 00000000..b1ede3d8 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -0,0 +1,197 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import torch + + +@dataclass +class BertFeatureTask: + norm_text: str + word2ph: List[int] + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_feature: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareBertBatchWorker: + def __init__( + self, + bert_model, + tokenizer, + device, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 16, + max_batch_tokens: int = 4096, + ): + self.bert_model = bert_model + self.tokenizer = tokenizer + self.device = device + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_tokens = max(16, int(max_batch_tokens)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[BertFeatureTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) + self.worker_thread.start() + + def 
_estimate_task_tokens(self, task: BertFeatureTask) -> int: + return max(1, len(task.norm_text) + 2) + + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_feature is not None + return task.result_feature, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_tokens": self.max_batch_tokens, + } + + def _collect_batch(self) -> List[BertFeatureTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] + batch_tokens = self._estimate_task_tokens(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_tokens = self._estimate_task_tokens(next_task) + if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + break + batch.append(self.pending_tasks.popleft()) + batch_tokens += next_tokens + + self.active_batch_size = len(batch) + if 
self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + return batch + + def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _run_batch(self, batch: List[BertFeatureTask]) -> None: + batch_started = time.perf_counter() + texts = [task.norm_text for task in batch] + batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + hidden = outputs["hidden_states"][-3].detach().cpu() + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + text_len = len(task.word2ph) + if text_len != len(task.norm_text): + raise AssertionError( + f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}" + ) + seq_len = 
int(attention_mask_cpu[batch_index].sum().item()) + char_features = hidden[batch_index, 1 : seq_len - 1] + if char_features.shape[0] != text_len: + raise AssertionError( + f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}" + ) + phone_level_feature = [] + for char_index, repeat_count in enumerate(task.word2ph): + phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1)) + task.result_feature = torch.cat(phone_level_feature, dim=0).T + task.profile = { + "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_forward_ms": float(forward_ms), + "bert_tokenize_ms": float(tokenize_ms), + "bert_scatter_ms": 0.0, + "bert_calls": 1.0, + "bert_stage_slots": float(limiter_stats["slots"]), + "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "bert_batch_size": float(len(batch)), + "bert_batch_tokens": float(batch_tokens), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_feature is not None: + task.profile["bert_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py new file mode 100644 index 00000000..7a1f9a53 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -0,0 +1,262 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import librosa +import numpy as np +import torch + + 
+REF_AUDIO_MIN_SAMPLES_16K = 48000 +REF_AUDIO_MAX_SAMPLES_16K = 160000 + + +def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: + wav_mono = raw_audio + if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: + wav_mono = wav_mono.mean(0, keepdim=True) + wav16k = wav_mono.squeeze(0).cpu().numpy() + if raw_sr != 16000: + wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000) + if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: + raise OSError("参考音频在3~10秒范围外,请更换!") + wav16k = np.ascontiguousarray(wav16k, dtype=np.float32) + if zero_wav_samples > 0: + wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0) + return torch.from_numpy(wav16k) + + +def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: + if conv1d is None: + return input_lengths.to(dtype=torch.long) + kernel_size = int(conv1d.kernel_size[0]) + stride = int(conv1d.stride[0]) + padding = int(conv1d.padding[0]) + dilation = int(conv1d.dilation[0]) + output_lengths = torch.div( + input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1, + stride, + rounding_mode="floor", + ) + 1 + return torch.clamp(output_lengths, min=0).to(dtype=torch.long) + + +@dataclass +class RefSemanticTask: + raw_audio: torch.Tensor + raw_sr: int + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_prompt_semantic: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareRefSemanticBatchWorker: + def __init__( + self, + ssl_model, + vits_model, + device, + is_half: bool, + zero_wav_samples: int, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 8, + max_batch_samples: int = 
def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Queue one reference clip and block until its result is available.

    Args:
        raw_audio: Audio tensor stored on the queued task as-is.
        raw_sr: Sample rate of ``raw_audio``; coerced to ``int``.

    Returns:
        ``(prompt_semantic, profile)`` where ``profile`` is a copy of the
        task's timing dict.

    Raises:
        Exception: Whatever exception was recorded on the task while it was
            being processed is re-raised in the calling thread.
    """
    request = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
    with self.condition:
        self.pending_tasks.append(request)
        self.total_submitted += 1
        self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
        # Wake the batching loop that waits on this condition.
        self.condition.notify_all()

    # No timeout: the caller waits until the task's done_event is set.
    request.done_event.wait()

    failure = request.error
    if failure is not None:
        raise failure
    semantic = request.result_prompt_semantic
    assert semantic is not None
    return semantic, dict(request.profile)
def _collect_batch(self) -> List[RefSemanticTask]:
    """Block until work exists, then gather tasks for one batch.

    The first pending task is always taken. Further tasks are added, in
    queue order, until one of these holds: the batch has
    ``max_batch_items`` tasks, the batching window (``batch_window_s`` after
    the first task was taken) has elapsed, or the next task's estimated
    sample count would push the batch past ``max_batch_samples``.

    Returns:
        The collected tasks. The ``active_batch_*`` counters and their peaks
        are updated before returning.
    """
    with self.condition:
        while not self.pending_tasks:
            self.condition.wait()

        first_task = self.pending_tasks.popleft()
        collected: List[RefSemanticTask] = [first_task]
        sample_total = self._estimate_task_samples(first_task)
        deadline = time.perf_counter() + self.batch_window_s

        while len(collected) < self.max_batch_items:
            time_left = deadline - time.perf_counter()
            if time_left <= 0:
                break
            if not self.pending_tasks:
                # Sleep until a submit notifies us or the window closes.
                self.condition.wait(timeout=time_left)
                continue
            candidate_samples = self._estimate_task_samples(self.pending_tasks[0])
            if sample_total + candidate_samples > self.max_batch_samples:
                # Leave the head of the queue for the next batch.
                break
            collected.append(self.pending_tasks.popleft())
            sample_total += candidate_samples

        self.active_batch_size = len(collected)
        self.active_batch_samples = sample_total
        self.active_batch_peak = max(self.active_batch_peak, self.active_batch_size)
        self.active_batch_samples_peak = max(self.active_batch_samples_peak, self.active_batch_samples)
        return collected
@torch.inference_mode()
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
    """Extract prompt-semantic codes for every task in ``batch``.

    The method prepares each clip on the CPU, pads the clips into one batch,
    runs the SSL model and ``vits_model.extract_latent`` once, and slices the
    per-task codes back out. Each task ends with either
    ``result_prompt_semantic`` and ``profile`` set or ``error`` set, and its
    ``done_event`` is signalled.

    A clip that fails CPU preparation (for example one outside the accepted
    duration range) is rejected on its own: only that task receives the
    error, and the remaining tasks are still processed. A failure while
    building the batch or running the models is reported to every task that
    reached that stage.

    Args:
        batch: Tasks collected for one forward pass.
    """
    # Local import so the block does not depend on the file's import header.
    from contextlib import nullcontext

    batch_started = time.perf_counter()

    prepared_start = time.perf_counter()
    runnable: List[RefSemanticTask] = []
    prepared_wavs: List[torch.Tensor] = []
    for task in batch:
        try:
            prepared = prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples)
        except Exception as exc:  # noqa: PERF203
            # Fail only the offending request; its batch mates are unaffected.
            task.error = exc
            task.done_event.set()
        else:
            runnable.append(task)
            prepared_wavs.append(prepared)
    cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
    if not runnable:
        return

    try:
        wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
        batch_samples = int(wav_lengths.sum().item())
        max_wav_len = int(wav_lengths.max().item())

        # Right-pad every clip with zeros and mark the real samples in the mask.
        input_values_cpu = torch.zeros((len(runnable), max_wav_len), dtype=torch.float32)
        attention_mask_cpu = torch.zeros((len(runnable), max_wav_len), dtype=torch.long)
        for batch_index, wav in enumerate(prepared_wavs):
            wav_len = int(wav.shape[0])
            input_values_cpu[batch_index, :wav_len] = wav
            attention_mask_cpu[batch_index, :wav_len] = 1

        # Without a limiter these defaults are reported in the profile.
        default_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
        limiter_scope = nullcontext(default_stats) if self.stage_limiter is None else self.stage_limiter.enter()
        with limiter_scope as limiter_stats:
            input_values = input_values_cpu.to(self.device)
            attention_mask = attention_mask_cpu.to(self.device)
            if self.is_half:
                input_values = input_values.half()
            forward_start = time.perf_counter()
            outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
            hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
            hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
            codes = self.vits_model.extract_latent(hubert_feature)
            forward_ms = (time.perf_counter() - forward_start) * 1000.0

        # Number of valid code frames per task after the optional ssl_proj conv.
        code_lengths = conv1d_output_lengths(
            hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)
        )
    except Exception as exc:
        # Handled here (not left to the caller) so that tasks already rejected
        # above keep their own error and are not signalled a second time.
        for task in runnable:
            task.error = exc
            task.done_event.set()
        return

    scatter_start = time.perf_counter()
    for batch_index, task in enumerate(runnable):
        try:
            code_len = int(code_lengths[batch_index].item())
            # Clone so the task does not keep the whole batch tensor alive.
            task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
            task.profile = {
                "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
                "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
                "prompt_semantic_forward_ms": float(forward_ms),
                "prompt_semantic_scatter_ms": 0.0,
                "prompt_semantic_calls": 1.0,
                "prompt_semantic_stage_slots": float(limiter_stats["slots"]),
                "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
                "prompt_semantic_batch_size": float(len(runnable)),
                "prompt_semantic_batch_samples": float(batch_samples),
            }
        except Exception as exc:  # noqa: PERF203
            task.error = exc
    scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
    for task in runnable:
        if task.result_prompt_semantic is not None:
            task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
        task.done_event.set()
finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index c8643991..de498573 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,5 +1,6 @@ from __future__ import annotations +from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -123,31 +124,58 @@ def prepare_request_state( prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) text = spec.text.strip("\n") - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile: Dict[str, float] = {} + text_features_profile: Dict[str, float] = {} + text_feature_pair_start = time.perf_counter() + prompt_future: Future | None = None + + def _extract_prompt_features(): + _sync_device(device) + prompt_start = time.perf_counter() + result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) + _sync_device(device) + return result, (time.perf_counter() - prompt_start) * 1000.0 + + if getattr(tts, "prepare_text_cpu_executor", None) is not None: + prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) _sync_device(device) text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) _sync_device(device) text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + + if prompt_future is None: + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, 
prompt_norm_text = tts.extract_text_features( + prompt_text, spec.prompt_lang, profile=prompt_text_profile + ) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile["parallel_future_wait_ms"] = 0.0 + else: + prompt_wait_start = time.perf_counter() + (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() + prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 + + text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 if phones is None: raise ValueError(f"{spec.request_id} text preprocessing returned no phones") _sync_device(device) - prompt_semantic_start = time.perf_counter() - prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + ref_audio_bundle_start = time.perf_counter() + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + prompt_semantic = ref_audio_bundle["prompt_semantic"].long() + spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + raw_audio = ref_audio_bundle["raw_audio"] + raw_sr = int(ref_audio_bundle["raw_sr"]) _sync_device(device) - prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 - - _sync_device(device) - ref_spec_start = time.perf_counter() - spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) - _sync_device(device) - ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 + bundle_profile = ref_audio_bundle.get("profile", {}) + prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) + ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) + audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() @@ -164,8 +192,43 @@ def prepare_request_state( 
prepare_profile = { "prompt_text_features_ms": prompt_text_features_ms, "text_features_ms": text_features_ms, + "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), + "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), + "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), + "text_feature_pair_ms": text_feature_pair_ms, + "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), + "audio_load_ms": audio_load_ms, + 
def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor:
    """Return ``mask`` with its last dimension made exactly ``target_len`` long.

    Args:
        mask: 4-D mask tensor; only the last dimension is resized.
        target_len: Desired size of the last dimension (expected to be
            non-negative).

    Returns:
        ``mask`` itself when it already has the requested length. A longer
        mask is cut down to its last ``target_len`` positions. A shorter mask
        is extended on the left by ``_pad_decode_mask_left``.
    """
    current_len = mask.shape[-1]
    if current_len > target_len:
        # Use an explicit start index: ``mask[..., -target_len:]`` would
        # return the entire mask when target_len is 0.
        return mask[:, :, :, current_len - target_len :]
    if current_len < target_len:
        return _pad_decode_mask_left(mask, target_len)
    return mask
def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState:
    """Build the request state for ``spec`` while tracking concurrent preparations.

    The in-flight counter and its peak are updated while holding
    ``self.condition``; the call to ``prepare_request_state`` itself runs
    without that lock. The counter is decremented again on every exit path.

    Args:
        spec: Request description forwarded to ``prepare_request_state``.

    Returns:
        The prepared state, with ``worker_prepare_inflight_on_enter`` and
        ``worker_prepare_peak_inflight`` added to its ``prepare_profile``.
    """
    with self.condition:
        self.prepare_inflight += 1
        inflight_on_enter = self.prepare_inflight
        self.prepare_peak_inflight = max(self.prepare_peak_inflight, inflight_on_enter)
        peak_on_enter = self.prepare_peak_inflight

    try:
        prepared = prepare_request_state(self.tts, spec)
        profile = prepared.prepare_profile
        profile["worker_prepare_inflight_on_enter"] = float(inflight_on_enter)
        profile["worker_prepare_peak_inflight"] = float(peak_on_enter)
        return prepared
    finally:
        with self.condition:
            # Never let the counter go below zero.
            self.prepare_inflight = max(0, self.prepare_inflight - 1)
"prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_bert_stage": bert_stage, + "prepare_bert_batch_worker": bert_batch_worker, + "prepare_ref_audio_stage": ref_audio_stage, + "prepare_ref_semantic_batch_worker": ref_semantic_batch_worker, "tracked_jobs": len(self.job_map), "total_submitted": self.total_submitted, "total_finished": self.total_finished, @@ -907,10 +939,36 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", + "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-CPU-Ms": 
f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Batch-Size": str( + int(prepare_profile.get("prompt_semantic_batch_size", 0.0)) + ), "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", + "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + "X-Prepare-Inflight-On-Enter": str( + int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) + ), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)